enable_x64()
Make JAX use double precision (64-bit floats and integers) by default.
It must be called before any other JAX code is executed.
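A minimal usage sketch follows. It assumes that enable_x64() amounts to switching on JAX's own `jax_enable_x64` configuration flag. The source does not say how the function is implemented, so the local `enable_x64` below is a hypothetical stand-in and not this library's code. The example shows the call ordering the entry requires: enable x64 first, then create arrays.

```python
import jax

# Hypothetical stand-in: assumes enable_x64() simply sets JAX's
# jax_enable_x64 flag. Not the library's actual implementation.
def enable_x64():
    jax.config.update("jax_enable_x64", True)

# Call it first, before any arrays are created or functions are traced.
enable_x64()

import jax.numpy as jnp

x = jnp.arange(3.0)
print(x.dtype)  # float64 (would be float32 without the flag)
```

Calling the function after arrays already exist can leave earlier values in 32-bit precision, which is why the ordering matters.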