refactor code
This commit is contained in:
@@ -12,7 +12,7 @@ result1 = A - B
|
||||
result2 = torch.sub(A, B)
|
||||
|
||||
# 方法3: 手动实现广播机制并作差
|
||||
def mysub(a:torch.Tensor, b:torch.Tensor):
|
||||
def my_sub(a:torch.Tensor, b:torch.Tensor):
|
||||
if not (
|
||||
(a.size(0) == 1 and b.size(1) == 1)
|
||||
or
|
||||
@@ -29,7 +29,7 @@ def mysub(a:torch.Tensor, b:torch.Tensor):
|
||||
result[i, j] = A_broadcasted[i, j] - B_broadcasted[i, j]
|
||||
return result
|
||||
|
||||
result3 = mysub(A, B)
|
||||
result3 = my_sub(A, B)
|
||||
|
||||
print("方法1的结果:")
|
||||
print(result1)
|
||||
|
||||
Reference in New Issue
Block a user