refactor code

This commit is contained in:
2023-10-10 13:45:54 +08:00
parent 03be6f74c5
commit f814db12ae
4 changed files with 356 additions and 316 deletions

View File

@@ -12,7 +12,7 @@ result1 = A - B
result2 = torch.sub(A, B)
# 方法3: 手动实现广播机制并作差
def mysub(a:torch.Tensor, b:torch.Tensor):
def my_sub(a:torch.Tensor, b:torch.Tensor):
if not (
(a.size(0) == 1 and b.size(1) == 1)
or
@@ -29,7 +29,7 @@ def mysub(a:torch.Tensor, b:torch.Tensor):
result[i, j] = A_broadcasted[i, j] - B_broadcasted[i, j]
return result
result3 = mysub(A, B)
result3 = my_sub(A, B)
print("方法1的结果:")
print(result1)