Compute gradients in parallel
Here is part of my code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)
        self.linear2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        z = self.linear1(x)       # intermediate node
        y_pred = self.linear2(z)  # output layer
        return y_pred, z

model = SimpleNet().cuda()
# data.trn_dl and optimizer are defined elsewhere in my code
for epoch in range(1):
    model.train()
    for i, dt in enumerate(data.trn_dl):
        optimizer.zero_grad()
        output = model(dt[0])  # output = (y_pred, z)
        loss2 = 0
        # one pair of autograd.grad calls per sample in the batch (slow)
        for j in range(0, len(output[0])):
            l1 = torch.autograd.grad(output[0][j][0], output[1], create_graph=True)[0][j]
            l2 = torch.autograd.grad(output[0][j][1], output[1], create_graph=True)[0][j]
            loss2 = loss2 + abs(torch.sqrt(l1**2 + l2**2) - 1)
        loss1 = F.mse_loss(output[0], dt[1])
        loss = loss1 + loss2
        loss.backward()
        optimizer.step()
    if epoch % 100 == 0:
        print(loss1, loss2, loss)
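For reference, the extra term accumulated in the inner loop is a per-sample gradient-norm penalty. Writing $y_{j,k}$ for output[0][j][k] and $z_j$ for output[1][j] (a scalar here, since linear1 has a single output unit), it reads:

$$
L_2 = \sum_{j} \left| \sqrt{\left(\frac{\partial y_{j,0}}{\partial z_j}\right)^2 + \left(\frac{\partial y_{j,1}}{\partial z_j}\right)^2} - 1 \right|,
\qquad L = \mathrm{MSE}(y, \text{target}) + L_2 .
$$

In this toy model $\partial y_{j,k} / \partial z_j$ is just the weight entry $W^{(2)}_{k,0}$ of linear2, so every term is identical; with more layers in between, the terms differ per sample.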
So I need the gradient of the output layer with respect to an intermediate node (this is a simplified example; the real model has more layers in between), which I currently compute per sample with torch.autograd.grad(output[0][j][0], output[1], create_graph=True)[0][j]. However, this requires the for loop over every element of the batch, which is very slow. Is there a way to compute this gradient for the whole batch at once? Thank you!
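A minimal sketch of one common approach, assuming (as in SimpleNet above) that row j of y_pred depends only on row j of z, i.e. nothing in the model mixes samples across the batch. Under that assumption, the derivative of the summed output with respect to z_j is exactly ∂y_{j,0}/∂z_j, because every other term in the sum does not depend on z_j. A single torch.autograd.grad call on the summed output therefore returns every sample's gradient at once, stacked along the batch dimension, and you need one call per output component (two here) instead of two calls per sample. The helper names grad_penalty_loop and grad_penalty_batched are hypothetical, introduced only for this comparison, and the model runs on CPU so the sketch works without a GPU.

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 1, bias=False)
        self.linear2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        z = self.linear1(x)
        y_pred = self.linear2(z)
        return y_pred, z

def grad_penalty_loop(y_pred, z):
    # Hypothetical helper: the original per-sample loop, kept for comparison.
    loss2 = 0
    for j in range(len(y_pred)):
        l1 = torch.autograd.grad(y_pred[j][0], z, create_graph=True)[0][j]
        l2 = torch.autograd.grad(y_pred[j][1], z, create_graph=True)[0][j]
        loss2 = loss2 + abs(torch.sqrt(l1**2 + l2**2) - 1)
    return loss2

def grad_penalty_batched(y_pred, z):
    # Hypothetical helper. Assumes sample j's output depends only on sample j's z,
    # so d(sum_i y[i, k]) / dz[j] == dy[j, k] / dz[j] for every j.
    g1 = torch.autograd.grad(y_pred[:, 0].sum(), z, create_graph=True)[0]  # (batch, 1)
    g2 = torch.autograd.grad(y_pred[:, 1].sum(), z, create_graph=True)[0]  # (batch, 1)
    # create_graph=True keeps the penalty differentiable for loss.backward()
    return (torch.sqrt(g1**2 + g2**2) - 1).abs().sum()

torch.manual_seed(0)
model = SimpleNet()
x = torch.randn(8, 2)
y_pred, z = model(x)
loop_val = grad_penalty_loop(y_pred, z)
batched_val = grad_penalty_batched(y_pred, z)
print(loop_val.item(), batched_val.item())
```

The two values should agree up to floating-point error. One caveat: the summing trick no longer matches the loop when samples interact, for example through nn.BatchNorm layers in training mode, because the summed gradient then also picks up cross-sample terms ∂y_{i,0}/∂z_j with i ≠ j.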
Topic batch-normalization loss-function gradient-descent
Category Data Science