torch gradcheck and pytest
Here's a fun fact: torch.autograd.gradcheck is not compatible with pytest, due to the latter's assertion rewriting feature. Internally, gradcheck takes a copy of locals() and *splats it into another function. Unfortunately, pytest inserts various things related to assertion rewriting into locals(), and that breaks gradcheck.
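To make the failure mode concrete, here is a toy sketch of the "snapshot locals() and forward it" pattern. It is not PyTorch's or pytest's actual code: the names _helper, check, check_with_extra_local and _injected are hypothetical, and the extra local variable merely stands in for whatever a tool might add to the frame.

```python
def _helper(f, x, eps):
    """Inner routine that accepts exactly the public arguments."""
    return (f(x + eps) - f(x)) / eps


def check(f, x, eps=1e-6):
    # Snapshot the arguments and forward them as keyword arguments.
    kwargs = locals().copy()
    return _helper(**kwargs)


def check_with_extra_local(f, x, eps=1e-6):
    _injected = None  # hypothetical: stands in for a name added by a tool
    kwargs = locals().copy()  # the snapshot now also contains "_injected"
    return _helper(**kwargs)  # raises TypeError: unexpected keyword argument
```

The first function works because locals() holds exactly the parameters of _helper. As soon as any other name is present in the frame, the forwarded keywords no longer match the signature and the call fails.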
The easy solution is to turn off assertion rewriting by running pytest with --assert=plain.
The slightly more tedious solution is to write your own gradcheck. Below is a basic example for a scalar-valued function:
import torch


def gradcheck(f, x, eps=1e-06, atol=1e-05, rtol=0.001):
    """Assert that finite difference gradients of f wrt x match autograd"""
    # for convenience, we work on flattened inputs:
    shape = x.shape
    x = x.detach().reshape(-1)

    def _f(flattened):
        return f(flattened.reshape(shape))

    # reference gradient from autograd
    x1 = torch.clone(x)
    x1.requires_grad = True
    reference = torch.autograd.grad(_f(x1), x1)[0]

    # perturb a copy, so the in-place updates never touch the caller's tensor
    x = torch.clone(x)
    for i in range(x.shape[0]):
        # we do central differences: df ~ 1/eps (f(x+eps/2) - f(x-eps/2))
        x[i] += eps / 2  # x + eps/2
        fa = _f(x)
        x[i] -= eps  # x - eps/2
        fb = _f(x)
        d = (fa - fb) / eps
        torch.testing.assert_close(reference[i], d, rtol=rtol, atol=atol)
        x[i] += eps / 2  # back to x
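As a usage sketch, the checker can be called from ordinary pytest-style test functions. The snippet below repeats a compact copy of gradcheck so that it runs on its own. BadSquare and both test functions are hypothetical names made up for illustration. The inputs are float64 on purpose: with eps=1e-6, float32 rounding error would swamp the finite differences.

```python
import torch


def gradcheck(f, x, eps=1e-6, atol=1e-5, rtol=1e-3):
    """Compact copy of the checker above, so this snippet is self-contained."""
    shape = x.shape
    x = x.detach().reshape(-1).clone()
    x1 = x.clone().requires_grad_(True)
    reference = torch.autograd.grad(f(x1.reshape(shape)), x1)[0]
    for i in range(x.shape[0]):
        x[i] += eps / 2  # x + eps/2
        fa = f(x.reshape(shape))
        x[i] -= eps  # x - eps/2
        fb = f(x.reshape(shape))
        x[i] += eps / 2  # back to x
        torch.testing.assert_close(reference[i], (fa - fb) / eps, rtol=rtol, atol=atol)


class BadSquare(torch.autograd.Function):
    """x**2 with a deliberately wrong backward (hypothetical example)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * x  # bug on purpose: the correct factor is 2 * x


def test_sum_of_squares():
    # autograd and finite differences agree, so this passes silently
    x = torch.randn(3, 4, dtype=torch.float64)
    gradcheck(lambda t: (t ** 2).sum(), x)


def test_bad_backward_is_caught():
    # the wrong backward must make the checker raise AssertionError
    x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
    try:
        gradcheck(lambda t: BadSquare.apply(t).sum(), x)
    except AssertionError:
        return
    raise RuntimeError("expected the gradient mismatch to be detected")
```

The second test is the more useful one: it confirms that the checker actually fails when the analytic gradient is wrong, rather than only confirming that it passes on a correct one.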