jax.nn.dot_product_attention
Fun fact: jax.nn.dot_product_attention somewhat stealthily casts things to float32. Bit of a footgun for double-precision use.
It appears that flax does not do this.