一、正常情况,默认会自动计算梯度
代码
import torch
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) # 定义一个张量,需要求导。里面分别为x1, x2, x3, x4
y1 = x ** 2 # 即 y = x^2
y2 = x ** 3 # 即 y = x^3
z = y1 + y2 # z = x^2 + x^3
print(x.requires_grad)
print(y1, y1.requires_grad) # True 代表需要求导(计算梯度)
print(y2, y2.requires_grad) # True 代表需要求导(计算梯度)
print(z, z.requires_grad) # True 代表需要求导(计算梯度)
z.backward(torch.ones_like(z))
print(x.grad) # 输出梯度,即dz/dx。 dz/dx = 2x+3x^2,当x=1,梯度为5
运行结果
True
tensor([[ 1., 4.],
[ 9., 16.]], grad_fn=<PowBackward0>) True
tensor([[ 1., 8.],
[27., 64.]], grad_fn=<PowBackward0>) True
tensor([[ 2., 12.],
[36., 80.]], grad_fn=<AddBackward0>) True
tensor([[ 5., 16.],
[33., 56.]])
二、中断其中一个梯度
由于 z = y1 + y2
dz / dx = d(y1)/dx + d(y2)/dx
如果我们不想计算y2的梯度,最终 dz/dx = d(y1)/dx
代码
import torch
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) # 定义一个张量,需要求导。里面分别为x1, x2, x3, x4
y1 = x ** 2 # 即 y = x^2
with torch.no_grad(): # y2不需要求导,所以不会记录计算图,不会求导
y2 = x ** 3 # 即 y = x^3
z = y1 + y2 # z = x^2 + x^3
print(x.requires_grad)
print(y1, y1.requires_grad) # True 代表需要求导(计算梯度)
print(y2, y2.requires_grad) # False 代表需要求导(计算梯度)
print(z, z.requires_grad) # True 代表需要求导(计算梯度)
z.backward(torch.ones_like(z))
print(x.grad) # 输出梯度,即dz/dx。由于y2没有求导,故 dz/dx = 2x
运行结果
True
tensor([[ 1., 4.],
[ 9., 16.]], grad_fn=<PowBackward0>) True
tensor([[ 1., 8.],
[27., 64.]]) False
tensor([[ 2., 12.],
[36., 80.]], grad_fn=<AddBackward0>) True
tensor([[2., 4.],
[6., 8.]])
三、修改张量里的值,但不影响梯度
代码
import torch
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) # 定义一个张量,需要求导。里面分别为x1, x2, x3, x4
y = 2 * x
x.data += 100 # 只改变data的值,不会影响梯度
y.backward(torch.ones_like(y))
print(x.grad) # 输出梯度,即dy/dx。dy/dx = 2
print(x)
# 注意,如果上面是 y = x^2这种,dy/dx = 2*x, 则下面修改x的值,会影响梯度
运行结果
tensor([[202., 204.],
[206., 208.]])
tensor([[101., 102.],
[103., 104.]], requires_grad=True)
您可以选择一种方式赞助本站
支付宝扫一扫赞助
微信钱包扫描赞助
赏