Skip to content

shine.scene

Probabilistic scene builder for NumPyro models.

Translates YAML configuration into a differentiable forward model that renders galaxies with JAX-GalSim and evaluates the likelihood against observed data.

scene

SceneBuilder(config)

Builder for NumPyro probabilistic scene models.

Constructs the forward generative model for Bayesian shear inference by translating configuration parameters into NumPyro priors and building a differentiable rendering pipeline using JAX-GalSim.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `config` | `ShineConfig` | SHINE configuration object containing model specifications. |

Initialize the scene builder.

Parameters:

Name Type Description Default
config ShineConfig

SHINE configuration object.

required
Source code in shine/scene.py
def __init__(self, config: ShineConfig) -> None:
    """Create a scene builder bound to a SHINE configuration.

    Args:
        config: SHINE configuration object. It is kept on ``self.config``
            and read later when the NumPyro model functions are built.
    """
    self.config = config

build_model()

Build the NumPyro forward generative model for inference.

Returns:

Type Description
Callable

A NumPyro model function that can be passed to MCMC samplers.

Source code in shine/scene.py
def build_model(self) -> Callable:
    """Build the NumPyro forward generative model for inference.

    The returned model proceeds in four steps:

    1. Sample a global shear (``g1``, ``g2``) shared by every galaxy.
    2. Sample per-galaxy flux, half-light radius, optional intrinsic
       ellipticity and position under a ``galaxies`` plate of size
       ``config.image.n_objects``.
    3. Render each galaxy with JAX-GalSim (shear, then PSF convolution)
       and sum the stamps into one model image.
    4. Compare the model image to the data with a Gaussian likelihood of
       width ``config.image.noise.sigma``.

    Returns:
        A NumPyro model function that can be passed to MCMC samplers.
    """

    def model(
        observed_data: Optional[jnp.ndarray] = None, psf: Optional[Any] = None
    ) -> None:
        """NumPyro probabilistic model for shear inference.

        Args:
            observed_data: Observed image data (used as obs in likelihood).
                The rendered model image has shape
                ``(image.size_y, image.size_x)``.
            psf: Pre-built JAX-GalSim PSF object to avoid reconstruction
                overhead. Required: it is convolved with every galaxy.

        Raises:
            ValueError: If ``psf`` is not provided.
        """
        if psf is None:
            # Fail early with a clear message instead of deep inside
            # galsim.Convolve.
            raise ValueError(
                "model() requires a pre-built JAX-GalSim PSF via `psf`"
            )

        img_cfg = self.config.image
        gal_cfg = self.config.gal
        n_objects = img_cfg.n_objects

        # Pin the FFT size (min == max), presumably so array shapes stay
        # static under jit/vmap.
        fft_size = img_cfg.fft_size
        gsparams = galsim.GSParams(
            maximum_fft_size=fft_size, minimum_fft_size=fft_size
        )

        # 1. Global shear parameters (shared by every galaxy in the scene)
        g1 = self._parse_prior("g1", gal_cfg.shear.g1)
        g2 = self._parse_prior("g2", gal_cfg.shear.g2)
        shear = galsim.Shear(g1=g1, g2=g2)

        # 2. Galaxy population
        with numpyro.plate("galaxies", n_objects):
            flux = self._parse_prior("flux", gal_cfg.flux)
            hlr = self._parse_prior("hlr", gal_cfg.half_light_radius)

            # Intrinsic ellipticity; galaxies stay round when not configured.
            e1 = 0.0
            e2 = 0.0
            if gal_cfg.ellipticity is not None:
                e1 = self._parse_prior("e1", gal_cfg.ellipticity.e1)
                e2 = self._parse_prior("e2", gal_cfg.ellipticity.e2)

            # Position priors
            x_min, x_max, y_min, y_max = self._resolve_position_bounds()

            x = numpyro.sample("x", dist.Uniform(x_min, x_max))
            y = numpyro.sample("y", dist.Uniform(y_min, y_max))

        # 3. Differentiable rendering
        def render_one_galaxy(
            flux: float, hlr: float, e1: float, e2: float, x: float, y: float
        ) -> jnp.ndarray:
            """Render a single galaxy image with JAX-GalSim."""
            gal = galaxy_utils.get_jax_galaxy(
                gal_cfg, flux, hlr, e1, e2, gsparams=gsparams
            )
            gal = gal.shear(shear)
            gal = galsim.Convolve([gal, psf], gsparams=gsparams)
            # drawImage offsets are relative to the image center, in
            # pixels. This conversion assumes x/y are 0-indexed pixel
            # coordinates (TODO confirm against _resolve_position_bounds).
            return gal.drawImage(
                nx=img_cfg.size_x,
                ny=img_cfg.size_y,
                scale=img_cfg.pixel_scale,
                offset=(
                    x - img_cfg.size_x / 2 + 0.5,
                    y - img_cfg.size_y / 2 + 0.5,
                ),
            ).array

        # Sampled sites already have shape (n_objects,) from the plate.
        # Fixed parameters and the default 0.0 ellipticity are scalars.
        # jax.vmap needs every mapped argument to share one leading size.
        # atleast_1d alone gives (1,), which breaks for n_objects > 1, so
        # broadcast explicitly.
        def _ensure_objects(val: Any) -> jnp.ndarray:
            arr = jnp.atleast_1d(jnp.asarray(val))
            return jnp.broadcast_to(arr, (n_objects,))

        galaxy_images = jax.vmap(render_one_galaxy)(
            _ensure_objects(flux),
            _ensure_objects(hlr),
            _ensure_objects(e1),
            _ensure_objects(e2),
            _ensure_objects(x),
            _ensure_objects(y),
        )  # shape: (n_objects, size_y, size_x)
        # Sum the per-galaxy stamps into a single scene image.
        model_image = jnp.sum(galaxy_images, axis=0)

        # 4. Likelihood: independent Gaussian noise on every pixel
        sigma = img_cfg.noise.sigma
        numpyro.sample("obs", dist.Normal(model_image, sigma), obs=observed_data)

    return model

build_batched_model(n_batch)

Build a batched NumPyro model for N independent realizations.

Each batch element has independent shear and galaxy parameters, but shares the same PSF. The rendering is vmapped over the batch dimension for GPU-parallel execution.

Only supports n_objects=1 (Level 0 case). Multi-object batching is left for a future enhancement.

Parameters:

Name Type Description Default
n_batch int

Number of independent realizations to batch.

required

Returns:

Type Description
Callable

A NumPyro model function with signature model(observed_data, psf).

Source code in shine/scene.py
def build_batched_model(self, n_batch: int) -> Callable:
    """Build a batched NumPyro model for N independent realizations.

    Each batch element has independent shear and galaxy parameters, but
    shares the same PSF. The rendering is vmapped over the batch dimension
    for GPU-parallel execution.

    Only supports n_objects=1 (Level 0 case). Multi-object batching is
    left for a future enhancement.

    Args:
        n_batch: Number of independent realizations to batch. Must be a
            positive integer.

    Returns:
        A NumPyro model function with signature model(observed_data, psf).

    Raises:
        ValueError: If ``config.image.n_objects`` is not 1, or if
            ``n_batch`` is smaller than 1.
    """
    if self.config.image.n_objects != 1:
        raise ValueError(
            f"build_batched_model only supports n_objects=1, "
            f"got {self.config.image.n_objects}"
        )
    if n_batch < 1:
        raise ValueError(f"n_batch must be a positive integer, got {n_batch}")

    def model(
        observed_data: Optional[jnp.ndarray] = None, psf: Optional[Any] = None
    ) -> None:
        """Batched NumPyro model for parallel shear inference.

        Args:
            observed_data: Stacked observed images, shape
                ``(n_batch, image.size_y, image.size_x)``. The rendered
                array of each element has shape ``(ny, nx)``.
            psf: Pre-built JAX-GalSim PSF object (shared across batch).
                Required.

        Raises:
            ValueError: If ``psf`` is not provided.
        """
        if psf is None:
            # Fail early with a clear message instead of deep inside
            # galsim.Convolve.
            raise ValueError(
                "model() requires a pre-built JAX-GalSim PSF via `psf`"
            )

        img_cfg = self.config.image
        gal_cfg = self.config.gal

        # Pin the FFT size (min == max), presumably so array shapes stay
        # static under jit/vmap.
        fft_size = img_cfg.fft_size
        gsparams = galsim.GSParams(
            maximum_fft_size=fft_size, minimum_fft_size=fft_size
        )

        # Sample all parameters under a batch plate, so each batch element
        # gets independent draws. NUTS explores one dimension per sampled
        # site per element: up to 8*N (g1, g2, flux, hlr, e1, e2, x, y),
        # fewer when some priors are fixed.
        with numpyro.plate("batch", n_batch):
            g1 = self._parse_prior("g1", gal_cfg.shear.g1)
            g2 = self._parse_prior("g2", gal_cfg.shear.g2)

            flux = self._parse_prior("flux", gal_cfg.flux)
            hlr = self._parse_prior("hlr", gal_cfg.half_light_radius)

            # Intrinsic ellipticity; galaxies stay round when not configured.
            e1 = 0.0
            e2 = 0.0
            if gal_cfg.ellipticity is not None:
                e1 = self._parse_prior("e1", gal_cfg.ellipticity.e1)
                e2 = self._parse_prior("e2", gal_cfg.ellipticity.e2)

            x_min, x_max, y_min, y_max = self._resolve_position_bounds()
            x = numpyro.sample("x", dist.Uniform(x_min, x_max))
            y = numpyro.sample("y", dist.Uniform(y_min, y_max))

        # vmap rendering over batch dimension
        def render_single_scene(
            g1: float,
            g2: float,
            flux: float,
            hlr: float,
            e1: float,
            e2: float,
            x: float,
            y: float,
        ) -> jnp.ndarray:
            """Render a single scene (one galaxy, one shear)."""
            # Use the module-level galaxy_utils, as build_model does, rather
            # than re-importing on every trace.
            gal = galaxy_utils.get_jax_galaxy(
                gal_cfg, flux, hlr, e1, e2, gsparams=gsparams
            )
            shear = galsim.Shear(g1=g1, g2=g2)
            gal = gal.shear(shear)
            final = galsim.Convolve([gal, psf], gsparams=gsparams)
            # drawImage offsets are relative to the image center, in
            # pixels. This conversion assumes x/y are 0-indexed pixel
            # coordinates (TODO confirm against _resolve_position_bounds).
            return final.drawImage(
                nx=img_cfg.size_x,
                ny=img_cfg.size_y,
                scale=img_cfg.pixel_scale,
                offset=(
                    x - img_cfg.size_x / 2 + 0.5,
                    y - img_cfg.size_y / 2 + 0.5,
                ),
            ).array

        # Ensure all parameters have shape (n_batch,). Sampled params
        # already have this shape from the plate; fixed params need
        # broadcasting from scalar to (n_batch,).
        def _ensure_batch(val: Any) -> jnp.ndarray:
            arr = jnp.atleast_1d(jnp.asarray(val))
            return jnp.broadcast_to(arr, (n_batch,))

        g1 = _ensure_batch(g1)
        g2 = _ensure_batch(g2)
        flux = _ensure_batch(flux)
        hlr = _ensure_batch(hlr)
        e1 = _ensure_batch(e1)
        e2 = _ensure_batch(e2)
        x = _ensure_batch(x)
        y = _ensure_batch(y)

        model_images = jax.vmap(render_single_scene)(
            g1, g2, flux, hlr, e1, e2, x, y
        )  # shape: (n_batch, size_y, size_x)

        # Likelihood: .to_event(2) turns the two image axes into event
        # dims, so each batch element has its own full-image likelihood.
        sigma = img_cfg.noise.sigma
        numpyro.sample(
            "obs",
            dist.Normal(model_images, sigma).to_event(2),
            obs=observed_data,
        )

    return model