first commit
This commit is contained in:
57
Lab/Lab4/code/tune-clip-in-cub/get_loader.py
Normal file
57
Lab/Lab4/code/tune-clip-in-cub/get_loader.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
import os
|
||||
import clip
|
||||
|
||||
|
||||
class Classes:
|
||||
def __init__(self, classes_file):
|
||||
self.class2index = {}
|
||||
self.index2class = {}
|
||||
classes = open(classes_file).readlines()
|
||||
classes = [line.strip() for line in classes]
|
||||
for row in classes:
|
||||
index, birdname = row.split(' ')
|
||||
index = int(index)
|
||||
birdname = (birdname.split('.'))[1].replace('_', ' ')
|
||||
self.class2index['A photo of ' + birdname] = index - 1
|
||||
self.index2class[index - 1] = 'A photo of ' + birdname
|
||||
def __len__(self):
|
||||
return len(self.class2index)
|
||||
def get_class(self, num: int):
|
||||
return self.index2class[num] if (num in self.index2class) else None
|
||||
def get_id(self, class_name: str):
|
||||
return (
|
||||
self.class2index[class_name] if (class_name in self.class2index) else None
|
||||
)
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, processor, train=True):
|
||||
classes = Classes('/home/kejingfan/cub/classes.txt')
|
||||
class_list = [classes.get_class(i) for i in range(len(classes))]
|
||||
self.tokens = clip.tokenize(class_list)
|
||||
|
||||
self.img_process = processor
|
||||
self.root_dir = '/home/kejingfan/cub/images'
|
||||
images_list = open('/home/kejingfan/cub/images.txt').readlines()
|
||||
images_list = [line.strip().split(' ')[1] for line in images_list]
|
||||
self.images = []
|
||||
labels_file = open('/home/kejingfan/cub/image_class_labels.txt').readlines()
|
||||
labels = [int(line.strip().split(' ')[1]) for line in labels_file]
|
||||
train_test_split_file = open('/home/kejingfan/cub/train_test_split.txt').readlines()
|
||||
is_train = [line.strip().split(' ')[1] == '1' for line in train_test_split_file]
|
||||
for index in range(len(images_list)):
|
||||
class_id = labels[index]
|
||||
if (train and is_train[index]) or (not train and not is_train[index]):
|
||||
self.images.append([os.path.join(self.root_dir, images_list[index]), int(class_id) - 1])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image, target = self.images[index]
|
||||
token = self.tokens[target]
|
||||
image = Image.open(image).convert("RGB")
|
||||
image = self.img_process(image)
|
||||
return image, token, target
|
||||
22
Lab/Lab4/code/tune-clip-in-cub/test.py
Normal file
22
Lab/Lab4/code/tune-clip-in-cub/test.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import torch.nn
|
||||
import clip
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from get_loader import Classes
|
||||
|
||||
|
||||
def test(net, test_dataset, test_loader, device):
|
||||
net.eval()
|
||||
total_accuracy = 0.0
|
||||
texts = test_dataset.tokens.to(device)
|
||||
with torch.no_grad():
|
||||
for index, (images, tokens, targets) in tqdm(enumerate(test_loader), total=len(test_loader)):
|
||||
images = images.to(device)
|
||||
logits_per_image, logits_per_text = net(images, texts)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
accuracy = np.sum(probs.argmax(1) == targets.numpy())
|
||||
total_accuracy += accuracy
|
||||
net.train()
|
||||
return total_accuracy / len(test_dataset)
|
||||
73
Lab/Lab4/code/tune-clip-in-cub/train.py
Normal file
73
Lab/Lab4/code/tune-clip-in-cub/train.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
import clip
|
||||
|
||||
from get_loader import MyDataset
|
||||
from test import test
|
||||
|
||||
|
||||
def convert_models_to_fp32(model):
|
||||
for p in model.parameters():
|
||||
p.data = p.data.float()
|
||||
p.grad.data = p.grad.data.float()
|
||||
|
||||
|
||||
def train():
|
||||
batch_size = 64
|
||||
learning_rate = 1e-6
|
||||
num_epochs = 500
|
||||
|
||||
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
|
||||
net, preprocess = clip.load("ViT-L/14", device=device, jit=False)
|
||||
|
||||
if device == 'cpu':
|
||||
net.float()
|
||||
else:
|
||||
clip.model.convert_weights(net)
|
||||
|
||||
loss_img = nn.CrossEntropyLoss()
|
||||
loss_txt = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
|
||||
|
||||
train_dateset = MyDataset(processor=preprocess, train=True)
|
||||
train_loader = DataLoader(train_dateset, batch_size=batch_size, shuffle=True, num_workers=64, pin_memory=True)
|
||||
test_dataset = MyDataset(processor=preprocess, train=False)
|
||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=64, shuffle=True, pin_memory=True)
|
||||
|
||||
print(f'Train dataset size: {len(train_dateset)}\nTest dataset size: {len(test_dataset)}\n')
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
total_epoch_loss = 0
|
||||
for index, (images, tokens, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||
optimizer.zero_grad()
|
||||
images = images.to(device)
|
||||
tokens = tokens.to(device)
|
||||
with torch.set_grad_enabled(True):
|
||||
logits_per_image, logits_per_text = net(images, tokens)
|
||||
ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
|
||||
cur_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
|
||||
total_epoch_loss += cur_loss.item()
|
||||
cur_loss.backward()
|
||||
|
||||
if device == 'cpu':
|
||||
optimizer.step()
|
||||
else:
|
||||
convert_models_to_fp32(net)
|
||||
optimizer.step()
|
||||
clip.model.convert_weights(net)
|
||||
|
||||
test_acc = test(net, test_dataset, test_loader, device)
|
||||
print(f'Total train loss: {total_epoch_loss:.6f}, Test accuracy: {test_acc:.6%}')
|
||||
print("--------------------------------------------------------------")
|
||||
torch.save({'epoch': epoch,
|
||||
'model_state_dict': net.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': total_epoch_loss,
|
||||
}, f"model_checkpoint/model-{epoch + 1}_acc-{test_acc*100:.3f}.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
51
Lab/Lab4/code/tune-clip-in-stanford_cars/get_loader.py
Normal file
51
Lab/Lab4/code/tune-clip-in-stanford_cars/get_loader.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
import os
|
||||
import pandas as pd
|
||||
import clip
|
||||
|
||||
|
||||
class Classes:
|
||||
def __init__(self, classes_file):
|
||||
self.class2index = {}
|
||||
self.index2class = {}
|
||||
classes = pd.read_csv(classes_file)
|
||||
for index, row in classes.iterrows():
|
||||
carname = row['class_names']
|
||||
self.class2index['A photo of ' + carname] = index
|
||||
self.index2class[index] = 'A photo of ' + carname
|
||||
def __len__(self):
|
||||
return len(self.class2index)
|
||||
def get_class(self, num: int):
|
||||
return self.index2class[num] if (num in self.index2class) else None
|
||||
def get_id(self, class_name: str):
|
||||
return (
|
||||
self.class2index[class_name] if (class_name in self.class2index) else None
|
||||
)
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, processor, train=True):
|
||||
classes = Classes('/home/kejingfan/cars/class_names.csv')
|
||||
class_list = [classes.get_class(i) for i in range(len(classes))]
|
||||
self.tokens = clip.tokenize(class_list)
|
||||
|
||||
self.img_process = processor
|
||||
self.root_dir = '/home/kejingfan/cars' + ('/cars_' + ('train' if train else 'test')) * 2
|
||||
train_annos_file = '/home/kejingfan/cars/cars_train_annos.csv'
|
||||
test_annos_file = '/home/kejingfan/cars/cars_test_annos_withlabels.csv'
|
||||
images_list = pd.read_csv(train_annos_file if train else test_annos_file)
|
||||
self.images = []
|
||||
for index, row in images_list.iterrows():
|
||||
class_id = int(row['class'])
|
||||
self.images.append([os.path.join(self.root_dir, row['fname']), class_id - 1])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image, target = self.images[index]
|
||||
token = self.tokens[target]
|
||||
image = Image.open(image).convert("RGB")
|
||||
image = self.img_process(image)
|
||||
return image, token, target
|
||||
20
Lab/Lab4/code/tune-clip-in-stanford_cars/test.py
Normal file
20
Lab/Lab4/code/tune-clip-in-stanford_cars/test.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torch.nn
|
||||
import clip
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def test(net, test_dataset, test_loader, device):
|
||||
net.eval()
|
||||
total_accuracy = 0.0
|
||||
texts = test_dataset.tokens.to(device)
|
||||
with torch.no_grad():
|
||||
for index, (images, tokens, targets) in tqdm(enumerate(test_loader), total=len(test_loader)):
|
||||
images = images.to(device)
|
||||
logits_per_image, logits_per_text = net(images, texts)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
accuracy = np.sum(probs.argmax(1) == targets.numpy())
|
||||
total_accuracy += accuracy
|
||||
net.train()
|
||||
return total_accuracy / len(test_dataset)
|
||||
73
Lab/Lab4/code/tune-clip-in-stanford_cars/train.py
Normal file
73
Lab/Lab4/code/tune-clip-in-stanford_cars/train.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
import clip
|
||||
|
||||
from get_loader import MyDataset
|
||||
from test import test
|
||||
|
||||
|
||||
def convert_models_to_fp32(model):
|
||||
for p in model.parameters():
|
||||
p.data = p.data.float()
|
||||
p.grad.data = p.grad.data.float()
|
||||
|
||||
|
||||
def train():
|
||||
batch_size = 64
|
||||
learning_rate = 1e-6
|
||||
num_epochs = 500
|
||||
|
||||
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
|
||||
net, preprocess = clip.load("ViT-L/14", device=device, jit=False)
|
||||
|
||||
if device == 'cpu':
|
||||
net.float()
|
||||
else:
|
||||
clip.model.convert_weights(net)
|
||||
|
||||
loss_img = nn.CrossEntropyLoss()
|
||||
loss_txt = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
|
||||
|
||||
train_dateset = MyDataset(processor=preprocess, train=True)
|
||||
train_loader = DataLoader(train_dateset, batch_size=batch_size, shuffle=True, num_workers=64, pin_memory=True)
|
||||
test_dataset = MyDataset(processor=preprocess, train=False)
|
||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=64, shuffle=True, pin_memory=True)
|
||||
|
||||
print(f'Train dataset size: {len(train_dateset)}\nTest dataset size: {len(test_dataset)}\n')
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
total_epoch_loss = 0
|
||||
for index, (images, tokens, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||
optimizer.zero_grad()
|
||||
images = images.to(device)
|
||||
tokens = tokens.to(device)
|
||||
with torch.set_grad_enabled(True):
|
||||
logits_per_image, logits_per_text = net(images, tokens)
|
||||
ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
|
||||
cur_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
|
||||
total_epoch_loss += cur_loss.item()
|
||||
cur_loss.backward()
|
||||
|
||||
if device == 'cpu':
|
||||
optimizer.step()
|
||||
else:
|
||||
convert_models_to_fp32(net)
|
||||
optimizer.step()
|
||||
clip.model.convert_weights(net)
|
||||
|
||||
test_acc = test(net, test_dataset, test_loader, device)
|
||||
print(f'Total train loss: {total_epoch_loss:.6f}, Test accuracy: {test_acc:.6%}')
|
||||
print("--------------------------------------------------------------")
|
||||
torch.save({'epoch': epoch,
|
||||
'model_state_dict': net.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': total_epoch_loss,
|
||||
}, f"model_checkpoint/model-{epoch + 1}_acc-{test_acc*100:.3f}.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Reference in New Issue
Block a user