完成实验2

This commit is contained in:
2023-10-24 20:15:43 +08:00
parent 7e8fea0a3b
commit d76db395de
10 changed files with 739 additions and 161 deletions

View File

@@ -74,9 +74,37 @@ class Model_3_3(nn.Module):
x = self.fc3(x)
x = self.activate_fn(x)
return x
class Model_3_4(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(in_features=28 * 28, out_features=1024)
self.fc2 = nn.Linear(in_features=1024, out_features=256)
self.fc3 = nn.Linear(in_features=256, out_features=num_classes)
self.activate_fn = leaky_relu
def forward(self, x: torch.Tensor):
x = self.flatten(x)
x = self.fc1(x)
x = self.activate_fn(x)
x = self.fc2(x)
x = self.activate_fn(x)
x = self.fc3(x)
x = self.activate_fn(x)
return x
if __name__ == "__main__":
train_MNIST_CLS(Model=Model_3_1)
train_MNIST_CLS(Model=Model_3_2)
train_MNIST_CLS(Model=Model_3_3)
print("模型1开始训练激活函数为relu")
train_loss_3_1, test_acc_3_1 = train_MNIST_CLS(Model=Model_3_1) # 激活函数为relu
print("模型2开始训练激活函数为sigmoid")
train_loss_3_2, test_acc_3_2 = train_MNIST_CLS(Model=Model_3_2) # 激活函数为sigmoid
print("模型3开始训练激活函数为tanh")
train_loss_3_3, test_acc_3_3 = train_MNIST_CLS(Model=Model_3_3) # 激活函数为tanh
print("模型4开始训练激活函数为leaky_relu")
train_loss_3_4, test_acc_3_4 = train_MNIST_CLS(Model=Model_3_4) # 激活函数为leaky_relu