Analysis of the minimum size of every tile of a HEALPix grid

[1]:
NSIDE: int = 3  # grid resolution passed to HealpixCoordinates.from_nside below

Imports

[2]:
import os
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from numpyro.infer import MCMC, NUTS

from healpix_geometry_analysis.coordinates import HealpixCoordinates
from healpix_geometry_analysis.geometry.tile import TileGeometry
from healpix_geometry_analysis.problems.numpyro_sampler import NumpyroSamplerProblem
from healpix_geometry_analysis.problems.optax_optimizer import OptaxOptimizerProblem
from healpix_geometry_analysis.enable_x64 import enable_x64

# Turn on 64-bit floating point for JAX before any array below is created.
# NOTE(review): presumably wraps `jax.config.update("jax_enable_x64", True)`;
# confirm in healpix_geometry_analysis.enable_x64.
enable_x64()
/home/docs/checkouts/readthedocs.org/user_builds/healpix-geometry-analysis/envs/latest/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Initialize a coordinate object, which knows a few coordinate-system transformations

[3]:
# Coordinate helper for the NSIDE grid. Later cells use its `grid.nside`,
# `grid.average_pixel_size_degree`, `diag_from_phi_z` and `phi_z` members.
coord = HealpixCoordinates.from_nside(NSIDE)

Making a list of tiles

The equatorial region requires one tile per Northern Hemisphere ring

[4]:
# Step in z between neighbouring equatorial rings
delta_z = 2 / 3 / coord.grid.nside
# Step in longitude between neighbouring tile centres within a ring
delta_phi = 0.5 * jnp.pi / coord.grid.nside

# Northern-Hemisphere equatorial rings sit at z = n * delta_z. We need
# n = 0, ..., nside - 2; the ring n = nside - 1 (z = 2/3 - delta_z) is covered
# by the intermediate-region cell. The previous pair of strided `arange`
# calls (odd n / even n) stopped one ring short when nside is odd, silently
# dropping the ring n = nside - 2.
ring_eq = jnp.arange(0, coord.grid.nside - 1)
z_eq = ring_eq * delta_z
# Tile centres of adjacent equatorial rings alternate between phi = 0 and a
# half-step offset. The intermediate cell puts centres of ring n = nside - 1 at
# multiples of delta_phi (phi = 0), which fixes the parity: rings with
# n + nside even are offset by half a step, the others start at phi = 0.
# For even nside this reproduces the original selection exactly.
phi_eq = jnp.where((ring_eq + coord.grid.nside) % 2 == 0, 0.5 * delta_phi, 0.0)

k_eq, kp_eq = coord.diag_from_phi_z(phi_eq, z_eq)

The intermediate region requires all tiles of the ring at z = 2/3 - delta_z

[5]:
# Every tile of the ring at z = 2/3 - delta_z, starting at phi = 0 and stepping
# by delta_phi up to the largest multiple that does not exceed pi/4.
n_tiles_inter = coord.grid.nside // 2 + 1
phi_inter = delta_phi * jnp.arange(n_tiles_inter)
z_inter = jnp.full_like(phi_inter, 2 / 3 - delta_z)

k_inter, kp_inter = coord.diag_from_phi_z(phi_inter, z_inter)

The polar region requires all tiles with $0 < \phi \le \pi/4$ and $2/3 \le z < 1$

[6]:
# Polar tiles are labelled by an index i = 1..nside (equal to k + k' below) and
# a position j along it. Enumerate every (i, j) pair on a rectangular grid,
# then keep only the "triangle" j <= (i - 1) // 2.
ring_grid, pos_grid = jnp.meshgrid(
    jnp.arange(1, coord.grid.nside + 1),
    jnp.arange(coord.grid.nside),
)
in_triangle = pos_grid <= (ring_grid - 1) // 2
i_pol = ring_grid[in_triangle]
j_pol = pos_grid[in_triangle]

# Diagonal indices (k, k') of the selected tile centres
k_pol = 0.5 + j_pol
kp_pol = (i_pol - j_pol) - 0.5

Combine all diagonal indices and plot the selected tile centres (the geometry objects are created inside the solver functions below)

[7]:
# Diagonal indices of every tile to analyse: equatorial, intermediate, polar
k = jnp.concatenate((k_eq, k_inter, k_pol))
kp = jnp.concatenate((kp_eq, kp_inter, kp_pol))

# Visual check of the selection: tile centres in the (phi, z) plane
phi_centers, z_centers = coord.phi_z(k, kp)
fig, ax = plt.subplots()
ax.scatter(phi_centers, z_centers, s=10)
ax.set_xlabel(r"$\phi$")
ax.set_ylabel("$z$")

print(k.shape)
(7,)
../_images/notebooks_tile_min_size_1by1_14_1.png

Use the NUTS sampler

[8]:
%%time


@partial(jax.vmap, in_axes=[None, 0, 0, None])
def solve_with_nuts(direction, k_c, kp_c, random_seed=0):
    """Sample a tile with NUTS and keep the draw with the smallest distance.

    Vectorised with ``jax.vmap`` over the tile centres ``k_c`` / ``kp_c``;
    ``direction`` and ``random_seed`` are shared by every tile.

    Parameters
    ----------
    direction
        Direction flag forwarded to ``TileGeometry`` ("p" or "m" in this notebook).
    k_c, kp_c
        Diagonal indices of the tile centre (one element per tile after vmap).
    random_seed
        Seed for ``jax.random.PRNGKey``.

    Returns
    -------
    The sample dict produced by the MCMC run, restricted to the single draw
    whose ``"distance"`` entry is minimal.
    """
    tile = TileGeometry(
        coord=coord,
        k_center=k_c,
        kp_center=kp_c,
        direction=direction,
        distance="chord_squared",
    )
    sampler_problem = NumpyroSamplerProblem(tile, track_arc_length=True)

    # No warmup: the chain is used as a plain exploration of the tile.
    mcmc = MCMC(
        NUTS(sampler_problem.model),
        num_warmup=0,
        num_samples=10_000,
        jit_model_args=True,
        progress_bar=False,
    )
    mcmc.run(jax.random.PRNGKey(random_seed))

    draws = mcmc.get_samples()
    best = jnp.argmin(draws["distance"])
    return jax.tree.map(lambda leaf: leaf[best], draws)


random_seeds = dict(p=1, m=-1)
# One vmapped solve per direction; each entry holds the best draw for every tile.
samples = {
    direction: solve_with_nuts(direction, k, kp, seed)
    for direction, seed in random_seeds.items()
}

# Smallest arc length over all tiles and both directions.
min_arc_length = min(
    float(jnp.min(best_per_tile["arc_length_degree"]))
    for best_per_tile in samples.values()
)
average_size = coord.grid.average_pixel_size_degree
ratio = min_arc_length / average_size
print(f"{min_arc_length = :.4f}, {average_size = : .4f} {ratio = : .4f}")
/home/docs/checkouts/readthedocs.org/user_builds/healpix-geometry-analysis/envs/latest/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1271: UserWarning: Some donated buffers were not usable: bool[10000].
See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
/home/docs/checkouts/readthedocs.org/user_builds/healpix-geometry-analysis/envs/latest/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py:1271: UserWarning: Some donated buffers were not usable: bool[10000].
See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
min_arc_length = 14.1443, average_size =  19.5441 ratio =  0.7237
CPU times: user 25 s, sys: 761 ms, total: 25.8 s
Wall time: 20.1 s

Use the AdaBelief optimizer

[9]:
%%time


@partial(jax.vmap, in_axes=[None, 0, 0, None])
def solve_with_ada(direction, k_c, kp_c, random_seed=0):
    """Minimise the in-tile distance with projected AdaBelief and return it in degrees.

    Vectorised with ``jax.vmap`` over the tile centres ``k_c`` / ``kp_c``;
    ``direction`` and ``random_seed`` are shared by every tile.

    Parameters
    ----------
    direction
        Direction flag forwarded to ``TileGeometry`` ("p" or "m" in this notebook).
    k_c, kp_c
        Diagonal indices of the tile centre (one element per tile after vmap).
    random_seed
        Seed for ``jax.random.PRNGKey`` used to draw the initial parameters.

    Returns
    -------
    The loss at the final optimised parameters, converted by
    ``geometry.arc_length_degrees`` (the loss is the "chord_squared" distance).
    """
    geometry = TileGeometry(
        coord=coord,
        k_center=k_c,
        kp_center=kp_c,
        direction=direction,
        distance="chord_squared",
    )
    problem = OptaxOptimizerProblem(geometry)

    optimizer = problem.freeze_optimizer(optax.adabelief(1e-1))
    rng_key = jax.random.PRNGKey(random_seed)
    params = problem.initial_params(rng_key)
    opt_state = optimizer.init(params)

    for _ in range(100):
        grads = jax.grad(problem.loss)(params)
        # Zero out non-finite gradient components so they cannot corrupt the
        # optimizer state.
        grads = jax.tree.map(lambda x: jnp.where(jnp.isfinite(x), x, 0.0), grads)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        # Keep the parameters inside the geometry's box bounds.
        params = optax.projections.projection_box(
            params, problem.geometry.lower_bounds, problem.geometry.upper_bounds
        )

    # Evaluate the objective at the final, projected parameters. Reusing the loss
    # from the last in-loop gradient call would report the value *before* the
    # final update, i.e. not the optimizer's actual result.
    loss = problem.loss(params)
    arc_distance_deg = problem.geometry.arc_length_degrees(loss)
    return arc_distance_deg


random_seeds = dict(p=1, m=-1)
# Per-tile optimised arc lengths for both directions, flattened into one array.
arc_distance_deg = jnp.concatenate(
    tuple(solve_with_ada(direction, k, kp, seed) for direction, seed in random_seeds.items())
)
min_arc_length, average_size = jnp.min(arc_distance_deg), coord.grid.average_pixel_size_degree
ratio = min_arc_length / average_size
print(f"{min_arc_length = :.4f}, {average_size = : .4f} {ratio = : .4f}")
min_arc_length = 14.0890, average_size =  19.5441 ratio =  0.7209
CPU times: user 23.6 s, sys: 31.4 ms, total: 23.6 s
Wall time: 23.5 s