first commit
This commit is contained in:
73
Lab/Lab4/code/tune-clip-in-cub/train.py
Normal file
73
Lab/Lab4/code/tune-clip-in-cub/train.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
import clip
|
||||
|
||||
from get_loader import MyDataset
|
||||
from test import test
|
||||
|
||||
|
||||
def convert_models_to_fp32(model):
|
||||
for p in model.parameters():
|
||||
p.data = p.data.float()
|
||||
p.grad.data = p.grad.data.float()
|
||||
|
||||
|
||||
def train():
|
||||
batch_size = 64
|
||||
learning_rate = 1e-6
|
||||
num_epochs = 500
|
||||
|
||||
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
|
||||
net, preprocess = clip.load("ViT-L/14", device=device, jit=False)
|
||||
|
||||
if device == 'cpu':
|
||||
net.float()
|
||||
else:
|
||||
clip.model.convert_weights(net)
|
||||
|
||||
loss_img = nn.CrossEntropyLoss()
|
||||
loss_txt = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
|
||||
|
||||
train_dateset = MyDataset(processor=preprocess, train=True)
|
||||
train_loader = DataLoader(train_dateset, batch_size=batch_size, shuffle=True, num_workers=64, pin_memory=True)
|
||||
test_dataset = MyDataset(processor=preprocess, train=False)
|
||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=64, shuffle=True, pin_memory=True)
|
||||
|
||||
print(f'Train dataset size: {len(train_dateset)}\nTest dataset size: {len(test_dataset)}\n')
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
total_epoch_loss = 0
|
||||
for index, (images, tokens, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||
optimizer.zero_grad()
|
||||
images = images.to(device)
|
||||
tokens = tokens.to(device)
|
||||
with torch.set_grad_enabled(True):
|
||||
logits_per_image, logits_per_text = net(images, tokens)
|
||||
ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
|
||||
cur_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
|
||||
total_epoch_loss += cur_loss.item()
|
||||
cur_loss.backward()
|
||||
|
||||
if device == 'cpu':
|
||||
optimizer.step()
|
||||
else:
|
||||
convert_models_to_fp32(net)
|
||||
optimizer.step()
|
||||
clip.model.convert_weights(net)
|
||||
|
||||
test_acc = test(net, test_dataset, test_loader, device)
|
||||
print(f'Total train loss: {total_epoch_loss:.6f}, Test accuracy: {test_acc:.6%}')
|
||||
print("--------------------------------------------------------------")
|
||||
torch.save({'epoch': epoch,
|
||||
'model_state_dict': net.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': total_epoch_loss,
|
||||
}, f"model_checkpoint/model-{epoch + 1}_acc-{test_acc*100:.3f}.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Reference in New Issue
Block a user