first commit

2024-09-05 12:56:46 +08:00
commit 8fee98d39d
144 changed files with 2766 additions and 0 deletions


@@ -0,0 +1,57 @@
from PIL import Image
from torch.utils.data import Dataset
import os
import clip


class Classes:
    """Maps CUB class names (rendered as CLIP prompts) to zero-based indices and back."""

    def __init__(self, classes_file):
        self.class2index = {}
        self.index2class = {}
        with open(classes_file) as f:
            classes = [line.strip() for line in f.readlines()]
        for row in classes:
            # Each row looks like "1 001.Black_footed_Albatross".
            index, birdname = row.split(' ')
            index = int(index)
            birdname = birdname.split('.')[1].replace('_', ' ')
            prompt = 'A photo of ' + birdname
            self.class2index[prompt] = index - 1
            self.index2class[index - 1] = prompt

    def __len__(self):
        return len(self.class2index)

    def get_class(self, num: int):
        return self.index2class.get(num)

    def get_id(self, class_name: str):
        return self.class2index.get(class_name)


class MyDataset(Dataset):
    """CUB-200-2011 dataset yielding (preprocessed image, tokenized class prompt, class index)."""

    def __init__(self, processor, train=True):
        classes = Classes('/home/kejingfan/cub/classes.txt')
        class_list = [classes.get_class(i) for i in range(len(classes))]
        # Tokenize all class prompts once; __getitem__ indexes into this tensor.
        self.tokens = clip.tokenize(class_list)
        self.img_process = processor
        self.root_dir = '/home/kejingfan/cub/images'
        with open('/home/kejingfan/cub/images.txt') as f:
            images_list = [line.strip().split(' ')[1] for line in f.readlines()]
        with open('/home/kejingfan/cub/image_class_labels.txt') as f:
            labels = [int(line.strip().split(' ')[1]) for line in f.readlines()]
        with open('/home/kejingfan/cub/train_test_split.txt') as f:
            is_train = [line.strip().split(' ')[1] == '1' for line in f.readlines()]
        self.images = []
        for index in range(len(images_list)):
            class_id = labels[index]
            # Keep only the samples that belong to the requested split.
            if (train and is_train[index]) or (not train and not is_train[index]):
                self.images.append([os.path.join(self.root_dir, images_list[index]), class_id - 1])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path, target = self.images[index]
        token = self.tokens[target]
        image = Image.open(image_path).convert("RGB")
        image = self.img_process(image)
        return image, token, target
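
The loader above hard-codes the CUB-200-2011 metadata layout under /home/kejingfan/cub (classes.txt, images.txt, image_class_labels.txt, train_test_split.txt). A minimal sketch of how it is meant to be driven, assuming the file is saved as get_loader.py and the ViT-L/14 CLIP weights are available, might look like this:

import clip
import torch

from get_loader import MyDataset

# Load CLIP mainly to obtain its preprocessing transform; the device choice here is arbitrary.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device, jit=False)

train_set = MyDataset(processor=preprocess, train=True)
image, token, target = train_set[0]
print(image.shape)   # e.g. torch.Size([3, 224, 224]) after ViT-L/14's preprocess
print(token.shape)   # torch.Size([77]) tokenized "A photo of <bird name>" prompt
print(target)        # zero-based class index, also the row of train_set.tokens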


@@ -0,0 +1,22 @@
import torch
import numpy as np
from tqdm import tqdm


def test(net, test_dataset, test_loader, device):
    """Scores every test image against the tokenized prompts of all classes and returns top-1 accuracy."""
    net.eval()
    total_accuracy = 0.0
    texts = test_dataset.tokens.to(device)
    with torch.no_grad():
        for images, tokens, targets in tqdm(test_loader):
            images = images.to(device)
            # logits_per_image has shape (batch_size, num_classes).
            logits_per_image, logits_per_text = net(images, texts)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            # Count how many predicted class indices match the ground truth.
            total_accuracy += np.sum(probs.argmax(1) == targets.numpy())
    net.train()
    return total_accuracy / len(test_dataset)
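
Because test() scores each image against the tokenized prompts of every class, it can also be run on the pre-trained model before any fine-tuning to obtain a zero-shot baseline. A minimal sketch, assuming the files above are saved as get_loader.py and test.py, that the CUB paths hard-coded in get_loader.py exist, and with an arbitrary batch size and worker count:

import clip
import torch
from torch.utils.data import DataLoader

from get_loader import MyDataset
from test import test

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("ViT-L/14", device=device, jit=False)

test_dataset = MyDataset(processor=preprocess, train=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

# Every test image is scored against all class prompts with the unmodified CLIP weights.
zero_shot_acc = test(net, test_dataset, test_loader, device)
print(f"Zero-shot accuracy before fine-tuning: {zero_shot_acc:.4%}")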


@@ -0,0 +1,73 @@
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import clip

from get_loader import MyDataset
from test import test


def convert_models_to_fp32(model):
    """Cast parameters (and their gradients) back to fp32 so Adam can update them safely."""
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def train():
    batch_size = 64
    learning_rate = 1e-6
    num_epochs = 500
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    net, preprocess = clip.load("ViT-L/14", device=device, jit=False)
    if device.type == 'cpu':
        net.float()
    else:
        # CLIP ships fp16 weights on GPU; keep them in fp16 for the forward pass.
        clip.model.convert_weights(net)

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

    train_dataset = MyDataset(processor=preprocess, train=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=64, pin_memory=True)
    test_dataset = MyDataset(processor=preprocess, train=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=64, pin_memory=True)
    print(f'Train dataset size: {len(train_dataset)}\nTest dataset size: {len(test_dataset)}\n')

    os.makedirs("model_checkpoint", exist_ok=True)

    for epoch in range(num_epochs):
        total_epoch_loss = 0.0
        for images, tokens, targets in tqdm(train_loader):
            optimizer.zero_grad()
            images = images.to(device)
            tokens = tokens.to(device)

            # Symmetric contrastive loss: the i-th image in the batch matches the i-th prompt.
            logits_per_image, logits_per_text = net(images, tokens)
            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
            cur_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            total_epoch_loss += cur_loss.item()
            cur_loss.backward()

            if device.type == 'cpu':
                optimizer.step()
            else:
                # Adam's update is unstable in fp16, so step in fp32 and convert back afterwards.
                convert_models_to_fp32(net)
                optimizer.step()
                clip.model.convert_weights(net)

        test_acc = test(net, test_dataset, test_loader, device)
        print(f'Total train loss: {total_epoch_loss:.6f}, Test accuracy: {test_acc:.6%}')
        print("--------------------------------------------------------------")
        torch.save({'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': total_epoch_loss,
                    }, f"model_checkpoint/model-{epoch + 1}_acc-{test_acc * 100:.3f}.pt")


if __name__ == "__main__":
    train()
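
The checkpoint dict saved at the end of each epoch carries enough state to resume training: the completed epoch index, the model and optimizer state dicts, and the epoch loss. A minimal resume sketch, assuming the same hyperparameters are reused and using a hypothetical checkpoint file name that follows the model-{epoch + 1}_acc-{...}.pt pattern above:

import clip
import torch
from torch import optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("ViT-L/14", device=device, jit=False)
optimizer = optim.Adam(net.parameters(), lr=1e-6, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

# Hypothetical file name; pick whichever checkpoint the training loop actually produced.
checkpoint = torch.load("model_checkpoint/model-10_acc-75.000.pt", map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # 'epoch' stores the last completed epoch index
print(f"Resuming at epoch {start_epoch}; last epoch loss was {checkpoint['loss']:.6f}")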