first commit

This commit is contained in:
2024-09-05 12:56:46 +08:00
commit 8fee98d39d
144 changed files with 2766 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
import torch
import torch.nn
import clip
import numpy as np
from tqdm import tqdm
def test(net, test_dataset, test_loader, device):
net.eval()
total_accuracy = 0.0
texts = test_dataset.tokens.to(device)
with torch.no_grad():
for index, (images, tokens, targets) in tqdm(enumerate(test_loader), total=len(test_loader)):
images = images.to(device)
logits_per_image, logits_per_text = net(images, texts)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
accuracy = np.sum(probs.argmax(1) == targets.numpy())
total_accuracy += accuracy
net.train()
return total_accuracy / len(test_dataset)