first commit
This commit is contained in:
20
Lab/Lab4/code/tune-clip-in-stanford_cars/test.py
Normal file
20
Lab/Lab4/code/tune-clip-in-stanford_cars/test.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torch.nn
|
||||
import clip
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def test(net, test_dataset, test_loader, device):
|
||||
net.eval()
|
||||
total_accuracy = 0.0
|
||||
texts = test_dataset.tokens.to(device)
|
||||
with torch.no_grad():
|
||||
for index, (images, tokens, targets) in tqdm(enumerate(test_loader), total=len(test_loader)):
|
||||
images = images.to(device)
|
||||
logits_per_image, logits_per_text = net(images, texts)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
accuracy = np.sum(probs.argmax(1) == targets.numpy())
|
||||
total_accuracy += accuracy
|
||||
net.train()
|
||||
return total_accuracy / len(test_dataset)
|
||||
Reference in New Issue
Block a user