完成实验三
This commit is contained in:
@@ -1,14 +1,7 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.functional import *
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch import nn
|
||||
from torchvision import datasets, transforms
|
||||
from tqdm import tqdm
|
||||
from utils import *
|
||||
|
||||
import ipdb
|
||||
|
||||
class My_Dropout(nn.Module):
|
||||
def __init__(self, p, **kwargs):
|
||||
@@ -16,7 +9,7 @@ class My_Dropout(nn.Module):
|
||||
self.p = p
|
||||
self.mask = None
|
||||
|
||||
def forward(self, x:torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.training:
|
||||
self.mask = (torch.rand(x.shape) > self.p).to(dtype=torch.float32, device=x.device)
|
||||
return x * self.mask / (1 - self.p)
|
||||
@@ -27,7 +20,7 @@ class My_Dropout(nn.Module):
|
||||
if __name__ == "__main__":
|
||||
my_dropout = My_Dropout(p=0.5)
|
||||
nn_dropout = nn.Dropout(p=0.5)
|
||||
x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0, 9.0, 10.0]])
|
||||
print(f"输入:\n{x}")
|
||||
output_my_dropout = my_dropout(x)
|
||||
|
||||
Reference in New Issue
Block a user