Analysis of the minimum size of every tile of a HEALPix grid
[1]:
# HEALPix resolution parameter (Nside) analysed throughout this notebook
NSIDE: int = 3
Imports
[2]:
import os
from functools import partial
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from numpyro.infer import MCMC, NUTS
from healpix_geometry_analysis.coordinates import HealpixCoordinates
from healpix_geometry_analysis.geometry.tile import TileGeometry
from healpix_geometry_analysis.problems.numpyro_sampler import NumpyroSamplerProblem
from healpix_geometry_analysis.problems.optax_optimizer import OptaxOptimizerProblem
from healpix_geometry_analysis.enable_x64 import enable_x64
# Presumably switches JAX to 64-bit (x64) mode; called right after the imports,
# before any JAX arrays are created in the cells below.
enable_x64()
/home/docs/checkouts/readthedocs.org/user_builds/healpix-geometry-analysis/envs/latest/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Initialize a coordinate object, which knows a few coordinate-system transformations
[3]:
# Coordinate helper for the chosen NSIDE. The cells below use its `grid.nside`,
# `diag_from_phi_z` ((phi, z) -> diagonal (k, k')) and `phi_z` ((k, k') -> (phi, z)).
coord = HealpixCoordinates.from_nside(NSIDE)
Making a list of tiles
The equatorial region requires one tile per Northern Hemisphere ring: within an equatorial ring all pixels are identical up to a rotation in longitude.
[4]:
# Step in z between neighbouring equatorial rings
delta_z = 2 / 3 / coord.grid.nside
# Step in longitude between neighbouring pixel centers within a ring
delta_phi = 0.5 * jnp.pi / coord.grid.nside
# Equatorial rings are z = m * delta_z for m = 0 .. nside - 2 (m = nside - 1 is the
# "intermediate" ring handled in the next cell). Per Gorski et al. (2005), pixel
# centers of such a ring lie at phi = delta_phi * (j - s / 2), s = (nside - m + 1) mod 2,
# so WHICH rings contain a center at phi = 0 depends on the parity of nside:
# odd-m rings for even nside, even-m rings for odd nside. Hard-coding the even-nside
# pattern puts tiles off pixel centers (and skips a ring) for odd NSIDE.
first_meridian_m = (coord.grid.nside + 1) % 2
first_offset_m = coord.grid.nside % 2
# First, rings with a pixel centered at longitude = 0
z_meridian = jnp.arange(first_meridian_m, coord.grid.nside - 1, 2) * delta_z
phi_meridian = jnp.zeros_like(z_meridian)
# Next, rings whose pixel centers are shifted by a half-step over phi
z_offset = jnp.arange(first_offset_m, coord.grid.nside - 1, 2) * delta_z
phi_offset = jnp.full_like(z_offset, 0.5 * delta_phi)
z_eq = jnp.concatenate([z_meridian, z_offset])
phi_eq = jnp.concatenate([phi_meridian, phi_offset])
k_eq, kp_eq = coord.diag_from_phi_z(phi_eq, z_eq)
The intermediate region requires all tiles of the ring z = 2/3 - delta_z
[5]:
# Unlike the equatorial rings, this ring needs every tile: take the pixel centers
# at phi = 0, delta_phi, ..., (nside // 2) * delta_phi, all on z = 2/3 - delta_z.
phi_inter = delta_phi * jnp.arange(coord.grid.nside // 2 + 1)
z_inter = jnp.full(phi_inter.shape, 2 / 3 - delta_z, dtype=phi_inter.dtype)
k_inter, kp_inter = coord.diag_from_phi_z(phi_inter, z_inter)
The polar region requires all tiles with 0 < lon <= pi/4 and 2/3 <= z < 1
[6]:
# Polar tiles are enumerated by rectangular indices (i, j): build the full
# (i, j) grid first, then keep only the triangle j <= (i - 1) // 2.
i_grid, j_grid = jnp.meshgrid(
    jnp.arange(1, coord.grid.nside + 1),
    jnp.arange(coord.grid.nside),
)
in_triangle = j_grid <= (i_grid - 1) // 2
i_pol = i_grid[in_triangle]
j_pol = j_grid[in_triangle]
# Diagonal indices of the tile centers (note k + k' = i)
k_pol = 0.5 + j_pol
kp_pol = (i_pol - j_pol) - 0.5
Combine all diagonal indices and plot the tile centers (geometry objects are created per tile in the solver cells below)
[7]:
# Stack the diagonal indices of the equatorial, intermediate and polar tiles
k = jnp.concatenate((k_eq, k_inter, k_pol))
kp = jnp.concatenate((kp_eq, kp_inter, kp_pol))
# Visual sanity check: tile centers in (phi, z)
phi_centers, z_centers = coord.phi_z(k, kp)
fig, ax = plt.subplots()
ax.scatter(phi_centers, z_centers, s=10)
ax.set_xlabel(r"$\phi$")
ax.set_ylabel("$z$")
print(k.shape)
(7,)
Use the NUTS sampler
[8]:
%%time
@partial(jax.vmap, in_axes=[None, 0, 0, None])
def solve_with_nuts(direction, k_c, kp_c, random_seed=0):
geometry = TileGeometry(
coord=coord,
k_center=k_c,
kp_center=kp_c,
direction=direction,
distance="chord_squared",
)
problem = NumpyroSamplerProblem(geometry, track_arc_length=True)
kernel = NUTS(problem.model)
mcmc = MCMC(kernel, num_warmup=0, num_samples=10_000, jit_model_args=True, progress_bar=False)
rng_key = jax.random.PRNGKey(random_seed)
mcmc.run(rng_key)
samples = mcmc.get_samples()
argmin = jnp.argmin(samples["distance"])
return jax.tree.map(lambda x: x[argmin], samples)
random_seeds = {"p": 1, "m": -1}
samples = {direction: solve_with_nuts(direction, k, kp, seed) for direction, seed in random_seeds.items()}
min_arc_length = min(float(jnp.min(samples["arc_length_degree"])) for samples in samples.values())
average_size = coord.grid.average_pixel_size_degree
ratio = min_arc_length / average_size
print(f"{min_arc_length = :.4f}, {average_size = : .4f} {ratio = : .4f}")
/home/docs/checkouts/readthedocs.org/user_builds/healpix-geometry-analysis/envs/latest/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1271: UserWarning: Some donated buffers were not usable: bool[10000].
See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
/home/docs/checkouts/readthedocs.org/user_builds/healpix-geometry-analysis/envs/latest/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1271: UserWarning: Some donated buffers were not usable: bool[10000].
See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
min_arc_length = 14.1443, average_size = 19.5441 ratio = 0.7237
CPU times: user 25 s, sys: 761 ms, total: 25.8 s
Wall time: 20.1 s
Use the AdaBelief optimizer
[9]:
%%time
@partial(jax.vmap, in_axes=[None, 0, 0, None])
def solve_with_ada(direction, k_c, kp_c, random_seed=0):
geometry = TileGeometry(
coord=coord,
k_center=k_c,
kp_center=kp_c,
direction=direction,
distance="chord_squared",
)
problem = OptaxOptimizerProblem(geometry)
optimizer = problem.freeze_optimizer(optax.adabelief(1e-1))
rng_key = jax.random.PRNGKey(random_seed)
params = problem.initial_params(rng_key)
opt_state = optimizer.init(params)
for _ in range(100):
loss, grads = jax.value_and_grad(problem.loss)(params)
grads = jax.tree.map(lambda x: jnp.where(jnp.isfinite(x), x, 0.0), grads)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
params = optax.projections.projection_box(
params, problem.geometry.lower_bounds, problem.geometry.upper_bounds
)
arc_distance_deg = problem.geometry.arc_length_degrees(loss)
return arc_distance_deg
random_seeds = {"p": 1, "m": -1}
arc_distance_deg = jnp.concatenate(
[solve_with_ada(direction, k, kp, seed) for direction, seed in random_seeds.items()]
)
min_arc_length = jnp.min(arc_distance_deg)
average_size = coord.grid.average_pixel_size_degree
ratio = min_arc_length / average_size
print(f"{min_arc_length = :.4f}, {average_size = : .4f} {ratio = : .4f}")
min_arc_length = 14.0890, average_size = 19.5441 ratio = 0.7209
CPU times: user 23.6 s, sys: 31.4 ms, total: 23.6 s
Wall time: 23.5 s