2024-10-24

torch gradcheck and pytest

Here's a fun fact: torch.autograd.gradcheck is not compatible with pytest, thanks to the latter's assertion rewriting feature. Internally, gradcheck takes a copy of locals() and **-splats it into another function. Unfortunately, pytest's assertion rewriting inserts extra bookkeeping names into locals(), which breaks gradcheck.
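
For illustration, the failure tends to surface in an ordinary-looking test like the sketch below (whether it actually triggers depends on your torch and pytest versions and on which modules get rewritten; the test name and function are just placeholders):

import torch

def test_sin_grad():
    x = torch.randn(3, dtype=torch.float64, requires_grad=True)
    # in some setups, pytest's default assertion rewriting makes this call
    # fail inside gradcheck's internals instead of checking the gradient
    assert torch.autograd.gradcheck(torch.sin, (x,))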

The easy solution is to run pytest with assertion rewriting disabled, via --assert=plain.
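
For example, on the command line:

pytest --assert=plain

or, to make it the default for the whole suite, in pytest.ini:

[pytest]
addopts = --assert=plain

Note that this disables pytest's detailed assertion failure output everywhere, not just around gradcheck.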

The slightly more tedious solution is to write your own gradcheck. Below is a basic example for a scalar-valued function:

import torch

def gradcheck(f, x, eps=1e-6, atol=1e-5, rtol=1e-3):
    """Assert that finite-difference gradients of f w.r.t. x match autograd."""

    # for convenience, we work on flattened inputs:
    shape = x.shape
    x = x.detach().reshape(-1)

    def _f(flattened):
        return f(flattened.reshape(shape))

    # reference gradient from autograd
    x1 = torch.clone(x)
    x1.requires_grad = True
    reference = torch.autograd.grad(_f(x1), x1)[0]

    # perturb a fresh copy of x in place, one coordinate at a time
    x = torch.clone(x)
    for i in range(x.shape[0]):
        # we do central differences: df ~ 1/eps (f(x+eps/2) - f(x-eps/2))
        x[i] += eps / 2  # x + eps/2
        fa = _f(x)

        x[i] -= eps  # x - eps/2
        fb = _f(x)

        d = (fa - fb) / eps
        torch.testing.assert_close(reference[i], d, rtol=rtol, atol=atol)

        x[i] += eps / 2  # back to x
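
As a quick sanity check, usage might look like this (double-precision inputs keep the finite differences well behaved; the test name and function are illustrative):

def test_gradcheck_example():
    x = torch.randn(3, 4, dtype=torch.float64)
    # scalar-valued function of a (3, 4) tensor
    gradcheck(lambda t: (t ** 2).sin().sum(), x)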