一、连续性判断
连续性:底层数据的存储顺序与张量按行优先一维展开的元素顺序是否一致
代码
import torch
# 例1、连续性判断
x = torch.arange(6)
print(x)
print(x.is_contiguous()) # 是否连续(底层数据的存储顺序与张量按行优先一维展开的元素顺序是否一致)
print(x.stride()) # 返回张量的步长元组,即每个维度上的元素之间的跨度。(1,) 表示步长为1
print(x.storage()) # 返回张量的存储对象,即底层的数据存储对象
y = x.view(2, 3) # 重塑为2x3的矩阵
print(y)
print(y.is_contiguous()) # 是否连续,True
print(y.stride()) # 返回张量的步长元组,即每个维度上的元素之间的跨度。(3, 1) 表示第0维度(行)的步长为3,第1维度(列)的步长为1
print(y.storage()) # 返回张量的存储对象,即底层的数据存储对象
运行结果
tensor([0, 1, 2, 3, 4, 5])
True
(1,)
0
1
2
3
4
5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
tensor([[0, 1, 2],
[3, 4, 5]])
True
(3, 1)
0
1
2
3
4
5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
二、view 和 reshape 操作后的张量与之前的共享内存
代码
import torch
# 例2、共享内存测试
x = torch.arange(6)
y = x.view(2, 3) # 或 reshape
x[0] = 520 # 修改x,发现y也变了
print(x)
print(y)
运行结果
tensor([520, 1, 2, 3, 4, 5])
tensor([[520, 1, 2],
[ 3, 4, 5]])
三、对不连续的张量进行view重塑会报错
代码
# 例3、view操作非连续的会报错
x = torch.arange(6)
y = x.view(2, 3)
print(y)
z = y.t() # 转置
print(z)
print(z.storage())
print(z.is_contiguous()) # False,转置后不连续。因为实际存储的是 0 1 2 3 4 5 6 ,但是按行展开是 0 3 1 4 2 5
print(y.view(3, 2)) # y连续,可以重塑
# 问题
# print(z.view(2, 3)) # z不连续,不能重塑,会报错 RuntimeError: view size is not compatible with input tensor's size and stride
# 解决办法
z = z.contiguous() # 使z连续
print(z.is_contiguous())
print(z.view(2, 3)) # z连续,可以重塑
运行结果
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])
0
1
2
3
4
5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
False
tensor([[0, 1],
[2, 3],
[4, 5]])
True
tensor([[0, 3, 1],
[4, 2, 5]])
四、对不连续的张量进行reshape重塑不会报错
代码
# 例4、reshape操作非连续的会报错
x = torch.arange(6)
y = x.view(2, 3) # 重塑为 2*3的矩阵
print(y)
z = y.t() # 转置,变为 3*2的矩阵
print(z)
print(z.is_contiguous()) # False,转置后不连续。因为实际存储的是 0 1 2 3 4 5 6 ,但是按行展开是 0 3 1 4 2 5
# 问题
print(z.reshape(2, 3)) # 重塑为2*3的矩阵。reshape不会报错,view会
运行结果
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])
False
tensor([[0, 3, 1],
[4, 2, 5]])
五、clone解决view和reshape共享内存的问题,创建新的内存副本
代码
# 例5、clone共享内存测试
x = torch.arange(6)
y = x.view(2, 3) # view, reshape都是共享内存
z = y.clone() # 克隆,不共享内存
x[0] = 520 # 修改x,发现y也变了
print(x)
print(y) # y会因为x变而变
print(z) # z不变
运行结果
tensor([520, 1, 2, 3, 4, 5])
tensor([[520, 1, 2],
[ 3, 4, 5]])
tensor([[0, 1, 2],
[3, 4, 5]])
您可以选择一种方式赞助本站
支付宝扫一扫赞助
微信钱包扫描赞助
赏