PyTorch入门学习 4-连续性、view, reshape, clone

avatar 2024年04月14日21:53:22 0 522 views
博主分享免费Java教学视频,B站账号:Java刘哥 ,长期提供技术问题解决、项目定制:本站商品点此

一、连续性判断

连续性:底层数据的存储顺序与张量按行优先一维展开的元素顺序是否一致

代码

import torch
# 例1、连续性判断
x = torch.arange(6)
print(x)
print(x.is_contiguous())  # 是否连续(底层数据的存储顺序与张量按行优先一维展开的元素顺序是否一致)
print(x.stride())  # 返回张量的步长元组,即每个维度上的元素之间的跨度。(1,) 表示步长为1
print(x.storage()) # 返回张量的存储对象,即底层的数据存储对象

y = x.view(2, 3)  # 重塑为2x3的矩阵
print(y)
print(y.is_contiguous())  # 是否连续,True
print(y.stride())  # 返回张量的步长元组,即每个维度上的元素之间的跨度。(3, 1) 表示第0维度(行)的步长为3,第1维度(列)的步长为1
print(y.storage()) # 返回张量的存储对象,即底层的数据存储对象

 

运行结果

tensor([0, 1, 2, 3, 4, 5])
True
(1,)
 0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
tensor([[0, 1, 2],
        [3, 4, 5]])
True
(3, 1)
 0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]

 

二、view 和 reshape 操作后的张量与之前的共享内存

代码

import torch
# 例2、共享内存测试
x = torch.arange(6)
y = x.view(2, 3)  # 或 reshape
x[0] = 520  # 修改x,发现y也变了
print(x)
print(y)

运行结果

tensor([520,   1,   2,   3,   4,   5])
tensor([[520,   1,   2],
        [  3,   4,   5]])

 

三、对不连续的张量进行view重塑会报错

代码

# 例3、view操作非连续的会报错
x = torch.arange(6)
y = x.view(2, 3)
print(y)
z = y.t()  # 转置
print(z)
print(z.storage())
print(z.is_contiguous())  # False,转置后不连续。因为实际存储的是 0 1 2 3 4 5 6 ,但是按行展开是 0 3 1 4 2 5
print(y.view(3, 2))  # y连续,可以重塑
# 问题
# print(z.view(2, 3))  # z不连续,不能重塑,会报错 RuntimeError: view size is not compatible with input tensor's size and stride

# 解决办法
z = z.contiguous() # 使z连续
print(z.is_contiguous())
print(z.view(2, 3)) # z连续,可以重塑

运行结果

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 3],
        [1, 4],
        [2, 5]])
 0
 1
 2
 3
 4
 5
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 6]
False
tensor([[0, 1],
        [2, 3],
        [4, 5]])
True
tensor([[0, 3, 1],
        [4, 2, 5]])

 

四、对不连续的张量进行reshape重塑不会报错

代码

# 例4、reshape操作非连续的会报错
x = torch.arange(6)
y = x.view(2, 3)  # 重塑为 2*3的矩阵
print(y)
z = y.t()  # 转置,变为 3*2的矩阵
print(z)
print(z.is_contiguous())  # False,转置后不连续。因为实际存储的是 0 1 2 3 4 5 6 ,但是按行展开是 0 3 1 4 2 5
# 问题
print(z.reshape(2, 3))  # 重塑为2*3的矩阵。reshape不会报错,view会

运行结果

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 3],
        [1, 4],
        [2, 5]])
False
tensor([[0, 3, 1],
        [4, 2, 5]])

 

五、clone解决view和reshape共享内存的问题,创建新的内存副本

代码

# 例5、clone共享内存测试
x = torch.arange(6)
y = x.view(2, 3)  # view, reshape都是共享内存
z = y.clone()  # 克隆,不共享内存
x[0] = 520  # 修改x,发现y也变了
print(x)
print(y)  # y会因为x变而变
print(z)  # z不变

运行结果

tensor([520,   1,   2,   3,   4,   5])
tensor([[520,   1,   2],
        [  3,   4,   5]])
tensor([[0, 1, 2],
        [3, 4, 5]])

 

 

 

 

 

  • 微信
  • 交流学习,资料分享
  • weinxin
  • 个人淘宝
  • 店铺名:言曌博客咨询部

  • (部分商品未及时上架淘宝)
avatar

发表评论

avatar 登录者:匿名
匿名评论,评论回复后会有邮件通知

  

已通过评论:0   待审核评论数:0