Gradient passthrough in PyTorch
I need to quantize the inputs, but the method I need for this (bucketize) is non-differentiable. I can of course detach the tensor, but then I lose the flow of gradients to the earlier weights. The question is really quite simple: how do you keep gradients flowing through a non-differentiable operation when necessary? For example, using the following code ...
x = self.linear1(x)
x_min, x_max = int(x.min()), int(x.max())  # renamed to avoid shadowing the built-ins min/max
bins = torch.linspace(x_min, x_max + 1, 16, device=x.device)
x = torch.bucketize(x.detach(), bins).float()  # forced to detach here; cast back to float for the next linear layer
x = self.linear2(x)
I know it's possible, because it was done in VQ-VAE and the gradients still flow fine for their quantization purposes. But I checked several versions of the VQ-VAE code and I can't see how the gradients could pass through the argmin step. It just seems to work regardless, which confuses me. I'm quite perplexed by this. I would be grateful for any help on this one. Thanks in advance.
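For reference, the pattern I keep seeing in VQ-VAE implementations boils down to the addition-of-a-detached-difference line below (a minimal, self-contained sketch with toy tensors, not code from any particular repo):

import torch

# Toy stand-ins: `x` plays the role of the layer output, `quantized` the bucketized values.
x = torch.randn(8, 4, requires_grad=True)
bins = torch.linspace(-3, 3, 16)
quantized = torch.bucketize(x.detach(), bins).float()

# Straight-through trick: the forward pass uses `quantized`, but because the
# difference is detached, the backward pass treats the whole expression as if
# it were just `x`, so gradients flow to `x` (and earlier weights) unchanged.
x_st = x + (quantized - x).detach()

x_st.sum().backward()
print(x.grad)  # all ones: autograd skipped the quantization step entirely

Is this line what actually lets the gradients "pass through", and would the same trick be the right way to handle bucketize in my case?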
Tags: gradient, pytorch, backpropagation, deep-learning, neural-network
Category: Data Science