PyTorch入门学习 11-梯度计算3,中断梯度追踪,想修改Tensor值不影响梯度

avatar 2024年04月16日11:28:38 0 775 views
博主分享免费Java教学视频,B站账号:Java刘哥 ,长期提供技术问题解决、项目定制:本站商品点此

一、正常情况,默认会自动计算梯度

代码

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)  # 定义一个张量,需要求导。里面分别为x1, x2, x3, x4
y1 = x ** 2  # 即 y = x^2
y2 = x ** 3  # 即 y = x^3
z = y1 + y2  # z = x^2 + x^3

print(x.requires_grad)
print(y1, y1.requires_grad)  # True 代表需要求导(计算梯度)
print(y2, y2.requires_grad)  # True 代表需要求导(计算梯度)
print(z, z.requires_grad)  # True 代表需要求导(计算梯度)
z.backward(torch.ones_like(z))
print(x.grad)  # 输出梯度,即dz/dx。 dz/dx = 2x+3x^2,当x=1,梯度为5

 

运行结果

True
tensor([[ 1.,  4.],
        [ 9., 16.]], grad_fn=<PowBackward0>) True
tensor([[ 1.,  8.],
        [27., 64.]], grad_fn=<PowBackward0>) True
tensor([[ 2., 12.],
        [36., 80.]], grad_fn=<AddBackward0>) True
tensor([[ 5., 16.],
        [33., 56.]])

 

二、中断其中一个梯度

由于 z = y1 + y2 

dz / dx = d(y1)/dx + d(y2)/dx

如果我们不想计算y2的梯度,最终 dz/dx = d(y1)/dx

 

代码

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)  # 定义一个张量,需要求导。里面分别为x1, x2, x3, x4
y1 = x ** 2  # 即 y = x^2
with torch.no_grad():  # y2不需要求导,所以不会记录计算图,不会求导
    y2 = x ** 3  # 即 y = x^3
z = y1 + y2  # z = x^2 + x^3

print(x.requires_grad)
print(y1, y1.requires_grad)  # True 代表需要求导(计算梯度)
print(y2, y2.requires_grad)  # False 代表需要求导(计算梯度)
print(z, z.requires_grad)  # True 代表需要求导(计算梯度)
z.backward(torch.ones_like(z))
print(x.grad)  # 输出梯度,即dz/dx。由于y2没有求导,故 dz/dx = 2x

 

运行结果

True
tensor([[ 1.,  4.],
        [ 9., 16.]], grad_fn=<PowBackward0>) True
tensor([[ 1.,  8.],
        [27., 64.]]) False
tensor([[ 2., 12.],
        [36., 80.]], grad_fn=<AddBackward0>) True
tensor([[2., 4.],
        [6., 8.]])

 

三、修改张量里的值,但不影响梯度

代码

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)  # 定义一个张量,需要求导。里面分别为x1, x2, x3, x4
y = 2 * x
x.data += 100  # 只改变data的值,不会影响梯度

y.backward(torch.ones_like(y))
print(x.grad)  # 输出梯度,即dy/dx。dy/dx = 2
print(x)

# 注意,如果上面是 y = x^2这种,dy/dx = 2*x, 则下面修改x的值,会影响梯度

 

运行结果

tensor([[202., 204.],
        [206., 208.]])
tensor([[101., 102.],
        [103., 104.]], requires_grad=True)

 

 

  • 微信
  • 交流学习,资料分享
  • weinxin
  • 个人淘宝
  • 店铺名:言曌博客咨询部

  • (部分商品未及时上架淘宝)
avatar

发表评论

avatar 登录者:匿名
匿名评论,评论回复后会有邮件通知

  

已通过评论:0   待审核评论数:0