2025-03-07

torch.nn.Linear initialisation

Note to self: The default weight initialisation for torch.nn.Linear samples from a uniform distribution on the interval [-1/fan_in**0.5, 1/fan_in**0.5]. Somewhat confusingly, this is implemented by calling the Kaiming (He) uniform initialiser with a mysterious-looking negative-slope argument a = sqrt(5), which a chain of internal calls transforms into the bound above.
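The cancellation can be checked with plain arithmetic. A minimal sketch of the chain inside kaiming_uniform_ (gain for leaky_relu with slope a, then std, then the uniform half-width), assuming an illustrative fan_in of 64:

```python
import math

fan_in = 64
a = math.sqrt(5)

# leaky_relu gain used by the Kaiming initialiser: sqrt(2 / (1 + a^2))
gain = math.sqrt(2.0 / (1.0 + a**2))   # = sqrt(1/3)
std = gain / math.sqrt(fan_in)         # = 1 / sqrt(3 * fan_in)
bound = math.sqrt(3.0) * std           # half-width of the uniform interval

# The sqrt(5) cancels down to the plain 1/sqrt(fan_in) bound.
assert math.isclose(bound, 1.0 / math.sqrt(fan_in))
```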

Note that reproducing this in JAX requires extra care: its variance_scaling initialiser multiplies the half-width of the uniform interval by 3**0.5 (so that the requested variance is actually achieved), so to reproduce PyTorch's behaviour one needs to call

jax.nn.initializers.variance_scaling(1.0/3.0, "fan_in", "uniform")
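Why 1/3 works can again be checked analytically: with mode "fan_in" and distribution "uniform", variance_scaling draws from U[-limit, limit] with limit = sqrt(3 * scale / fan_in). A sketch of that arithmetic, using the same illustrative fan_in of 64:

```python
import math

fan_in = 64
scale = 1.0 / 3.0

# variance_scaling(scale, "fan_in", "uniform") uses
# limit = sqrt(3 * scale / fan_in) as the uniform half-width.
limit = math.sqrt(3.0 * scale / fan_in)

# scale = 1/3 cancels the extra sqrt(3), recovering PyTorch's bound.
assert math.isclose(limit, 1.0 / math.sqrt(fan_in))
```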