# torch gradcheck and pytest

Here's a fun fact: `torch.autograd.gradcheck` is not compatible with `pytest` due to the latter's assertion rewriting feature. Internally, `gradcheck` takes a copy of `locals()` and `*splat`s it into another function. Unfortunately, `pytest` inserts various things related to assertion rewriting into `locals()`, and this breaks `gradcheck`.
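The failure mode can be reproduced in miniature. The sketch below is not torch's source: `check` and `_check_impl` are hypothetical names, and the stray `@py_assert1` entry is injected by hand to stand in for whatever assertion rewriting adds. Any unexpected key in the copied `locals()` becomes an unexpected keyword argument when it is splatted into the helper.

```python
def _check_impl(func, inputs, eps, atol):
    # Hypothetical stand-in for the internal helper that receives the arguments.
    return func(inputs) is not None and eps > 0 and atol > 0


def check(func, inputs, eps=1e-6, atol=1e-5, inject_extra=False):
    # Capture every argument by name, then drop the one the helper doesn't take.
    args = locals().copy()
    args.pop("inject_extra")
    if inject_extra:
        # Hand-injected stand-in for a name that assertion rewriting might add.
        args["@py_assert1"] = None
    return _check_impl(**args)
```

Calling `check(abs, -2.0)` succeeds, while `check(abs, -2.0, inject_extra=True)` raises a `TypeError` complaining about the unexpected keyword argument `'@py_assert1'`.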

The easy solution is to run `pytest` without assertion rewriting, using `--assert=plain`.

The slightly more tedious solution is to write your own `gradcheck`. Below is a basic example for a scalar function:

```python
import torch


def gradcheck(f, x, eps=1e-06, atol=1e-05, rtol=0.001):
    """Assert that finite difference gradients of f wrt x match autograd.

    f must return a scalar tensor. Use float64 inputs: with float32,
    eps=1e-6 is swamped by round-off error.
    """
    # for convenience, we work on flattened inputs:
    shape = x.shape
    x = x.detach().reshape(-1)

    def _f(flattened):
        return f(flattened.reshape(shape))

    # reference gradient from autograd
    x1 = torch.clone(x)
    x1.requires_grad = True
    reference = torch.autograd.grad(_f(x1), x1)[0]

    # perturb a private copy so the caller's tensor is never modified
    x = torch.clone(x)
    for i in range(x.shape[0]):
        original = x[i].item()
        # we do central differences: df ~ 1/eps (f(x+eps/2) - f(x-eps/2))
        x[i] = original + eps / 2  # x + eps/2
        fa = _f(x)
        x[i] = original - eps / 2  # x - eps/2
        fb = _f(x)
        x[i] = original  # restore x exactly, avoiding round-off drift
        d = (fa - fb) / eps
        torch.testing.assert_close(reference[i], d, rtol=rtol, atol=atol)
```
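A scalar-only check is limiting when a function returns a tensor. One way to extend the same central-difference idea, sketched here under the assumption of float64 inputs, is to build the finite-difference Jacobian one column at a time and compare it against `torch.autograd.functional.jacobian`. The name `jacobian_check` and the deliberately broken `BadSquare` function are hypothetical, included only to show a passing and a failing case.

```python
import torch


def jacobian_check(f, x, eps=1e-6, atol=1e-5, rtol=1e-3):
    """Hypothetical helper: assert the finite-difference Jacobian of f matches autograd.

    Assumes float64 inputs; f may return a tensor of any shape.
    """
    shape = x.shape
    flat = x.detach().reshape(-1).clone()

    def _f(v):
        return f(v.reshape(shape))

    # Autograd Jacobian has shape f(x).shape + (n,); flatten to (n_out, n).
    n = flat.numel()
    reference = torch.autograd.functional.jacobian(_f, flat).reshape(-1, n)

    with torch.no_grad():
        for i in range(n):
            original = flat[i].item()
            flat[i] = original + eps / 2
            fa = _f(flat).reshape(-1)
            flat[i] = original - eps / 2
            fb = _f(flat).reshape(-1)
            flat[i] = original  # restore exactly
            # Column i of the Jacobian via central differences.
            column = (fa - fb) / eps
            torch.testing.assert_close(reference[:, i], column, rtol=rtol, atol=atol)


class BadSquare(torch.autograd.Function):
    """Hypothetical op with a deliberately wrong backward (x instead of 2x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * x  # bug: should be grad_out * 2 * x


def test_tanh_jacobian():
    # Runs under plain pytest; no locals() splatting involved.
    jacobian_check(torch.tanh, torch.randn(3, 4, dtype=torch.float64))
```

`jacobian_check(BadSquare.apply, torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64))` raises an `AssertionError`, because the finite differences see the true derivative `2x` while autograd reports `x`.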