代码
import torch
x = torch.arange(1, 5).view(2, 2)
# 1 2
# 3 4
print(x)
# 例1、求对角线之和
print(torch.trace(x)) # 求矩阵的迹(对角线之和)
# 例2、求对角线元素
print(torch.diag(x))
# 例3、求矩阵上三角/下三角,其他位置填0
print(torch.triu(x)) # 上三角
print(torch.tril(x)) # 下三角
# 例4、求矩阵的乘法
# mm() 矩阵相乘,针对二维矩阵
a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([[5, 6], [7, 8]])
c = torch.mm(a, b) # 等价于 c = torch.matmul(a, b)
print(c, '\n')
# bmm() 批量矩阵相乘,针对三维矩阵
a = torch.Tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])
b = torch.Tensor([[[5, 6], [7, 8]],
[[1, 2], [3, 4]]])
c = torch.bmm(a, b) # 等价于 c = torch.matmul(a, b)
print(c, '\n')
# 例5、矩阵运算
# addmm() 加法
a = torch.Tensor([[1, 2],
[3, 4]])
m1 = torch.Tensor([[1, 2],
[3, 4]])
m2 = torch.Tensor([[5, 6],
[7, 8]])
c = torch.addmm(a, m1, m2) # 等价于 c = torch.add(a, torch.mm(m1, m2))
# addbmm() 批量矩阵相乘,等价于 c = torch.add(a, torch.bmm(m1, m2))
# addmv() 矩阵和向量相乘,等价于 c = torch.add(a, torch.mv(m1, m2))
# 例6、矩阵向量相乘 (补充下mv操作)
a = torch.Tensor([[1, 2],
[3, 4]])
b = torch.Tensor([1, 2])
print(torch.mv(a, b)) # 矩阵 * 向量
# 例7、矩阵转置
a = torch.Tensor([[1, 2],
[3, 4]])
print(torch.t(a)) # 转置
# 例8、矩阵求逆
a = torch.Tensor([[1, 2],
[3, 4]])
print(torch.inverse(a)) # 求逆
# 例9、矩阵求行列式
a = torch.Tensor([[1, 2],
[3, 4]])
print(torch.det(a)) # 求行列式
运行结果
tensor([[1, 2],
[3, 4]])
tensor(5)
tensor([1, 4])
tensor([[1, 2],
[0, 4]])
tensor([[1, 0],
[3, 4]])
tensor([[19., 22.],
[43., 50.]])
tensor([[[19., 22.],
[43., 50.]],
[[23., 34.],
[31., 46.]]])
tensor([ 5., 11.])
tensor([[1., 3.],
[2., 4.]])
tensor([[-2.0000, 1.0000],
[ 1.5000, -0.5000]])
tensor(-2.)
您可以选择一种方式赞助本站
支付宝扫一扫赞助
微信钱包扫描赞助
赏