Does torch.cat work with backpropagation?
I was wondering if it was okay to use torch.cat within my forward function. I am doing so because I want the first two columns of my input to skip the middle hidden layers and go directly to the final layer.
Here is my code: you can see that I use torch.cat at the last moment to make xcat.
Does the gradient propagate back through torch.cat, or does torch.cat hide what happened to my hidden variables?
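    import torch
    import torch.nn as nn

    class LinearRegressionForce(nn.Module):
        def __init__(self, focus_input_size, rest_input_size, hidden_size_1, hidden_size_2, output_size):
            super(LinearRegressionForce, self).__init__()
            # Keep the sizes so forward() can slice the input without relying on globals.
            self.focus_input_size = focus_input_size
            self.rest_input_size = rest_input_size
            self.in1 = nn.Linear(rest_input_size, hidden_size_1)
            self.middle1 = nn.Linear(hidden_size_1, hidden_size_2)
            self.out4 = nn.Linear(focus_input_size + hidden_size_2, output_size)

        def forward(self, inputs):
            # The first focus_input_size columns skip the hidden layers.
            focus_inputs = inputs[:, 0:self.focus_input_size]
            rest_inputs = inputs[:, self.focus_input_size:(self.rest_input_size + self.focus_input_size)]
            # clamp(min=0) acts as a ReLU.
            x = self.in1(rest_inputs).clamp(min=0)
            x = self.middle1(x).clamp(min=0)
            # Join the skipped columns with the hidden output just before the final layer.
            xcat = torch.cat((focus_inputs, x), 1)
            out = self.out4(xcat).clamp(min=0)
            return out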
I call it like so:
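    import numpy as np

    # Convert to float32 so both parts match the model's weights (Variable is not needed since PyTorch 0.4).
    rest_inputs = torch.from_numpy(rest_x_train).float()
    # Append a column of ones to the focus features; n is the number of training rows.
    focus_x_train_ones = np.concatenate((focus_x_train, np.ones((n, 1))), axis=1)
    focus_inputs = torch.from_numpy(focus_x_train_ones).float()
    inputs = torch.cat((focus_inputs, rest_inputs), 1)
    predicted = model(inputs).detach().numpy()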
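A small check like the following should show whether the layer that feeds into torch.cat still receives a gradient after a backward pass. This is only a sketch, not my real model: the layer sizes, the names hidden_layer and final_layer, and the summed output used as a stand-in loss are all made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: one hidden layer and a final layer that also
# receives the two skipped columns (2 skipped + 4 hidden = 6 inputs).
hidden_layer = nn.Linear(3, 4)
final_layer = nn.Linear(2 + 4, 1)

inputs = torch.randn(5, 5)   # 5 rows, 2 skipped columns + 3 remaining columns
focus = inputs[:, 0:2]       # columns that bypass the hidden layer
rest = inputs[:, 2:5]        # columns that go through the hidden layer

x = hidden_layer(rest).clamp(min=0)
xcat = torch.cat((focus, x), 1)

# Any scalar works as a stand-in loss for this check.
loss = final_layer(xcat).sum()
loss.backward()

# If torch.cat is part of the autograd graph, the hidden layer's weights
# have a populated .grad tensor after backward().
hidden_has_grad = hidden_layer.weight.grad is not None
print(hidden_has_grad, hidden_layer.weight.grad.shape)
```

If hidden_has_grad comes out True, torch.cat was recorded in the autograd graph and the hidden layer is being trained along with the final layer.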