2026-04-23

jax.nn.dot_product_attention

Fun fact: jax.nn.dot_product_attention somewhat stealthily casts intermediate values to float32, even when the inputs are float64. That's a bit of a footgun if you're working in double precision.

It appears that Flax does not do this.