wf_psf.data.data_zernike_utils

Utilities for Zernike Data Handling.

This module provides utility functions for working with Zernike coefficients, including:

  • Prior generation

  • Data loading

  • Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients

These utilities are useful wherever Zernike representations are used to model optical aberrations or to link physical misalignments to wavefront modes.

Author:

Tobias Liaudat <tobias.liaudat@cea.fr>

Functions

assemble_zernike_contributions(model_params)

Assemble Zernike contributions from prior, centroid correction, and CCD misalignment.

combine_zernike_contributions(contributions)

Combine multiple Zernike contributions, padding each to the max order before summing.

compute_zernike_tip_tilt(star_images[, ...])

Compute Zernike tip-tilt corrections for a batch of PSF images.

defocus_to_zk4_wavediff(dz[, ...])

Compute the Zernike 4 value for a given defocus in WaveDiff conventions.

defocus_to_zk4_zemax(dz[, tel_focal_length, ...])

Compute the Zernike 4 value for a given defocus in Zemax conventions.

get_np_zernike_prior(data)

Get the Zernike prior from the provided dataset.

pad_contribution_to_order(contribution, ...)

Pad a Zernike contribution array to the max Zernike order.

pad_tf_zernikes(zk_param, zk_prior, n_zks_total)

Pad the Zernike coefficient tensors to match the specified total number of Zernikes.

shift_x_y_to_zk1_2_wavediff(dxy[, ...])

Compute Zernike 1 (2) for a given shift in x (y) in WaveDiff conventions.

Classes

ZernikeInputs(zernike_prior, ...)

Zernike-related inputs for PSF modeling, including priors and datasets for corrections.

ZernikeInputsFactory()

Factory class to build ZernikeInputs based on run type and dataset configuration.

class wf_psf.data.data_zernike_utils.ZernikeInputs(zernike_prior: ndarray | None, centroid_dataset: dict | RecursiveNamespace | None, misalignment_positions: ndarray | None)

Bases: object

Zernike-related inputs for PSF modeling, including priors and datasets for corrections.

All fields are optional to allow flexibility across different run types (training, simulation, inference) and configurations.

Parameters:
  • zernike_prior (Optional[np.ndarray]) – The true Zernike prior, if provided (e.g., from PDC). Can be None if not used or not available.

  • centroid_dataset (Optional[Union[dict, "RecursiveNamespace"]]) – Dataset used for computing centroid corrections. Should contain both training and test sets if used. Can be None if centroid correction is not enabled or no dataset is available.

  • misalignment_positions (Optional[np.ndarray]) – Positions used for computing CCD misalignment corrections. Should be available in inference mode if misalignment correction is enabled. Can be None if not used or not available.

centroid_dataset: dict | RecursiveNamespace | None
misalignment_positions: ndarray | None
zernike_prior: ndarray | None
class wf_psf.data.data_zernike_utils.ZernikeInputsFactory

Bases: object

Factory class to build ZernikeInputs based on run type and dataset configuration.

This class encapsulates the logic for extracting the Zernike-related inputs required by a given run type (training, simulation, inference) and set of model parameters. It decides which inputs are needed, extracts them from the dataset, and returns them as a ZernikeInputs dataclass instance.

Methods

build(data, run_type, model_params[, prior])

Build a ZernikeInputs dataclass instance based on run type and data.

static build(data, run_type: str, model_params, prior: ndarray | None = None) → ZernikeInputs

Build a ZernikeInputs dataclass instance based on run type and data.

Parameters:
  • data (Union[dict, DataConfigHandler]) – Dataset object containing star positions, priors, and optionally pixel data.

  • run_type (str) – One of ‘training’, ‘simulation’, or ‘inference’.

  • model_params (RecursiveNamespace) – Model parameters, including flags for prior/corrections.

  • prior (Optional[np.ndarray]) – An explicitly passed prior (overrides any inferred one if provided).

Return type:

ZernikeInputs

wf_psf.data.data_zernike_utils.assemble_zernike_contributions(model_params, zernike_prior=None, centroid_dataset=None, positions=None, batch_size=16)

Assemble Zernike contributions from prior, centroid correction, and CCD misalignment.

This function checks the model parameters to determine which contributions to include, computes each enabled contribution, and combines them into a single Zernike contribution tensor. Contributions that are disabled, or whose inputs are not available, are skipped, so the output reflects only the configured and available terms.

Parameters:
  • model_params (RecursiveNamespace) – Parameters controlling which contributions to apply.

  • zernike_prior (Optional[np.ndarray or tf.Tensor]) – The precomputed Zernike prior, either as a NumPy array or as a TensorFlow tensor. A tensor is converted to NumPy in eager mode.

  • centroid_dataset (Optional[object]) – Dataset used to compute centroid correction. Must have both training and test sets.

  • positions (Optional[np.ndarray or tf.Tensor]) – Positions used for computing CCD misalignment. Must be available in inference mode.

  • batch_size (int) – Batch size for centroid correction.

Returns:

A tensor representing the full Zernike contribution map.

Return type:

tf.Tensor

wf_psf.data.data_zernike_utils.combine_zernike_contributions(contributions: list[ndarray]) → ndarray

Combine multiple Zernike contributions, padding each to the max order before summing.
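The pad-then-sum behaviour can be sketched in NumPy as follows. This is a hypothetical illustration, not the library code: it assumes each contribution is a 2D array of shape (n_samples, n_zernikes) whose columns start at the lowest-order mode, so a shorter contribution is zero-padded at the end of the Zernike axis. The helper names pad_to_order and combine are made up for this sketch.

```python
import numpy as np


def pad_to_order(contribution, max_order):
    """Zero-pad the Zernike axis (axis 1) of a 2D contribution up to max_order."""
    n_missing = max_order - contribution.shape[1]
    if n_missing < 0:
        raise ValueError("contribution already exceeds max_order")
    # Pad only at the end of axis 1; higher-order modes are set to zero.
    return np.pad(contribution, ((0, 0), (0, n_missing)))


def combine(contributions):
    """Pad every contribution to the largest order present, then sum them."""
    max_order = max(c.shape[1] for c in contributions)
    return sum(pad_to_order(c, max_order) for c in contributions)


prior = np.ones((3, 5))          # 3 stars, 5 Zernike modes (illustrative values)
tip_tilt = np.full((3, 2), 0.5)  # 3 stars, only the first 2 modes
total = combine([prior, tip_tilt])
print(total.shape)  # (3, 5)
```

Padding before summing lets contributions truncated at different orders (for example a tip-tilt-only correction and a full prior) be added element-wise without misaligning modes.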

wf_psf.data.data_zernike_utils.compute_zernike_tip_tilt(star_images: ndarray, star_masks: ndarray | None = None, pixel_sampling: float = 1.2e-05, reference_shifts: list[float] = [-0.3333333333333333, -0.3333333333333333], sigma_init: float = 2.5, n_iter: int = 20) → ndarray

Compute Zernike tip-tilt corrections for a batch of PSF images.

This function estimates the centroid shifts of multiple PSFs and computes the corresponding Zernike tip-tilt corrections to align them with a reference.

Parameters:
  • star_images (np.ndarray) – A batch of PSF images (3D array of shape (num_images, height, width)).

  • star_masks (np.ndarray, optional) – A batch of masks with the same shape as star_images. A mask value of 0 ignores the pixel, 1 gives it full weight, and values in (0, 1) act as weights for partial consideration. Defaults to None.

  • pixel_sampling (float, optional) – The pixel size in meters. Defaults to 12e-6 m (12 microns).

  • reference_shifts (list[float], optional) – The target centroid shifts in pixels, specified as [dy, dx]. Defaults to [-1/3, -1/3] (nominal Euclid conditions).

  • sigma_init (float, optional) – Initial standard deviation for centroid estimation. Default is 2.5.

  • n_iter (int, optional) – Number of iterations for centroid refinement. Default is 20.

Returns:

An array of shape (num_images, 2), where:

  • Column 0 contains Zk1 (tip) values.

  • Column 1 contains Zk2 (tilt) values.

Return type:

np.ndarray

Notes

  • This function processes all images at once using vectorized operations.

  • The Zernike coefficients are computed in the WaveDiff convention.
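The centroid-estimation step can be illustrated with a vectorized, iteratively Gaussian-weighted centroid. This is a hypothetical sketch of that general technique, not the library's estimator. It keeps the window width fixed at sigma, whereas the real function may adapt it from sigma_init. It also stops at centroids in pixels and does not apply reference_shifts or the conversion to Zk1/Zk2. The function name gaussian_weighted_centroids is made up.

```python
import numpy as np


def gaussian_weighted_centroids(images, masks=None, sigma=2.5, n_iter=20):
    """Hypothetical sketch: iterative Gaussian-weighted centroids of a PSF batch.

    images: array of shape (n, h, w).
    masks:  same shape; 0 ignores a pixel, values in (0, 1] weight it.
    Returns an (n, 2) array of (y, x) centroids in pixels.
    """
    images = np.asarray(images, dtype=float)
    masks = np.ones_like(images) if masks is None else np.asarray(masks, dtype=float)
    _, h, w = images.shape
    yy, xx = np.mgrid[0:h, 0:w]  # pixel coordinate grids, shape (h, w)

    # Start from the plain (masked) first moments of each image.
    data = images * masks
    flux = data.sum(axis=(1, 2))
    yc = (data * yy).sum(axis=(1, 2)) / flux
    xc = (data * xx).sum(axis=(1, 2)) / flux

    for _ in range(n_iter):
        # Gaussian window centred on each image's current estimate (vectorized).
        r2 = (yy - yc[:, None, None]) ** 2 + (xx - xc[:, None, None]) ** 2
        weighted = data * np.exp(-r2 / (2.0 * sigma**2))
        norm = weighted.sum(axis=(1, 2))
        yc = (weighted * yy).sum(axis=(1, 2)) / norm
        xc = (weighted * xx).sum(axis=(1, 2)) / norm

    return np.stack([yc, xc], axis=1)


# Usage: a single Gaussian PSF centred at (y, x) = (15.3, 16.7) in a 32x32 stamp.
grid_y, grid_x = np.mgrid[0:32, 0:32]
psf = np.exp(-((grid_y - 15.3) ** 2 + (grid_x - 16.7) ** 2) / (2 * 2.0**2))
cents = gaussian_weighted_centroids(psf[None])  # close to [[15.3, 16.7]]
```

The Gaussian window down-weights noisy outer pixels, and iterating re-centres it on each updated estimate. Broadcasting the (n, 1, 1) centroids against the (h, w) grids processes the whole batch without a Python loop over images.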

wf_psf.data.data_zernike_utils.defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2)

Compute the Zernike 4 value for a given defocus in WaveDiff conventions.

All inputs should be in [m].

The output Zernike coefficient is in [um], as expected by WaveDiff.

Parameters:
  • dz (float) – Shift in the z-axis, perpendicular to the focal plane. Units in [m].

  • tel_focal_length (float) – Telescope focal length in [m].

  • tel_diameter (float) – Telescope aperture diameter in [m].
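For intuition, here is a sketch of the standard paraxial relation between a focal-plane defocus and a Noll-normalized Z4 coefficient. It assumes that a defocus dz in a beam of f-number N = tel_focal_length / tel_diameter produces a quadratic wavefront error W020 = dz / (8 N^2), with signs ignored. Writing rho^2 = Z4 / (2 sqrt(3)) + 1/2 for Z4 = sqrt(3) (2 rho^2 - 1) then gives a Z4 coefficient of W020 / (2 sqrt(3)). This is not necessarily the exact normalization or sign convention used by the library. The function name defocus_to_zk4_um is hypothetical.

```python
import numpy as np


def defocus_to_zk4_um(dz, tel_focal_length=24.5, tel_diameter=1.2):
    """Sketch: Noll-normalized Z4 coefficient in [um] for a defocus dz in [m].

    Assumes the paraxial relation W020 = dz / (8 N^2), with N = f / D, and a
    Noll-normalized defocus term Z4 = sqrt(3) * (2 rho^2 - 1).
    """
    f_number = tel_focal_length / tel_diameter
    w020 = dz / (8.0 * f_number**2)      # peak quadratic wavefront error [m]
    zk4_m = w020 / (2.0 * np.sqrt(3.0))  # project rho^2 onto Z4 [m]
    return zk4_m * 1e6                   # [m] -> [um], the unit WaveDiff expects


zk4 = defocus_to_zk4_um(100e-6)  # 100 um defocus with the default f/20.4 beam; about 8.66e-3 um
```

The relation is linear in dz, so doubling the defocus doubles the coefficient. The strong dependence on N^2 explains why slow (high f-number) beams such as the default one tolerate large mechanical defocus.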

wf_psf.data.data_zernike_utils.defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2)

Compute the Zernike 4 value for a given defocus in Zemax conventions.

All inputs should be in [m].

Parameters:
  • dz (float) – Shift in the z-axis, perpendicular to the focal plane. Units in [m].

  • tel_focal_length (float) – Telescope focal length in [m].

  • tel_diameter (float) – Telescope aperture diameter in [m].

wf_psf.data.data_zernike_utils.get_np_zernike_prior(data)

Get the Zernike prior from the provided dataset.

This function concatenates the per-star priors of the training and test datasets to obtain the full prior.

Parameters:

data (DataConfigHandler) – Object containing training and test datasets.

Returns:

zernike_prior – NumPy array containing the full prior.

Return type:

np.ndarray
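The concatenation can be pictured with plain arrays. This hypothetical sketch assumes that each dataset stores its prior as an array of shape (n_stars, n_zernikes) and that training stars come first. It does not use the real DataConfigHandler attributes, and the function name concat_priors is made up.

```python
import numpy as np


def concat_priors(train_prior, test_prior):
    """Hypothetical sketch: stack per-star priors along the star axis, training first."""
    return np.concatenate([train_prior, test_prior], axis=0)


train = np.zeros((4, 15))  # 4 training stars, 15 Zernike modes (illustrative)
test = np.ones((2, 15))    # 2 test stars
full_prior = concat_priors(train, test)
print(full_prior.shape)  # (6, 15)
```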

wf_psf.data.data_zernike_utils.pad_contribution_to_order(contribution: ndarray, max_order: int) → ndarray

Pad a Zernike contribution array to the max Zernike order.

wf_psf.data.data_zernike_utils.pad_tf_zernikes(zk_param: Tensor, zk_prior: Tensor, n_zks_total: int)

Pad the Zernike coefficient tensors to match the specified total number of Zernikes.

Parameters:
  • zk_param (tf.Tensor) – Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1].

  • zk_prior (tf.Tensor) – Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1].

  • n_zks_total (int) – Total number of Zernikes to pad to.

Returns:

  • padded_zk_param (tf.Tensor) – Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1].

  • padded_zk_prior (tf.Tensor) – Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1].
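The padding described above can be mirrored in NumPy to show which axis changes. This is a hypothetical sketch, not the TensorFlow implementation. Only axis 1 (the Zernike axis) of each [batch, n_zks, 1, 1] array is zero-padded at the end. With TensorFlow, tf.pad accepts the same nested list of per-axis paddings. The function name pad_zernikes is made up.

```python
import numpy as np


def pad_zernikes(zk_param, zk_prior, n_zks_total):
    """NumPy sketch: zero-pad axis 1 of two [batch, n, 1, 1] arrays to n_zks_total."""

    def pad(zk):
        n_missing = n_zks_total - zk.shape[1]
        # Pad only the Zernike axis (axis 1), at the end; other axes untouched.
        return np.pad(zk, ((0, 0), (0, n_missing), (0, 0), (0, 0)))

    return pad(zk_param), pad(zk_prior)


zk_param = np.ones((8, 15, 1, 1))        # parametric part, 15 modes
zk_prior = np.full((8, 45, 1, 1), 0.1)   # prior part, 45 modes
p_param, p_prior = pad_zernikes(zk_param, zk_prior, 45)
total = p_param + p_prior                # shapes now match, so they can be summed
print(p_param.shape, p_prior.shape)  # (8, 45, 1, 1) (8, 45, 1, 1)
```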

wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2)

Compute Zernike 1 (2) for a given shift in x (y) in WaveDiff conventions.

All inputs should be in [m]. A displacement of, for example, 0.5 pixels should be scaled by the corresponding pixel scale, e.g. 12 [um], to get a displacement in [m], which would be dxy=0.5*12e-6.

The output Zernike coefficient is in [um], as expected by WaveDiff.

To match a centroid shifted by dx, compute the corresponding zk1 and generate the new PSF with -zk1.

The same applies to dy and zk2.

Parameters:
  • dxy (float) – Centroid shift in [m]. It can be on the x-axis or the y-axis.

  • tel_focal_length (float) – Telescope focal length in [m].

  • tel_diameter (float) – Telescope aperture diameter in [m].
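The unit handling above, together with the geometric link between tilt and image shift, can be sketched as follows. The sketch assumes a Noll-normalized tilt term Z2 = 2 rho cos(theta) over a pupil of radius D/2. Under that assumption a coefficient a produces a wavefront slope of 4a/D and an image-plane shift of f * 4a/D, so a = dxy * D / (4 f). WaveDiff's own Zernike normalization may use a different factor, so treat the numeric scale as illustrative only. The function name shift_to_tilt_um is hypothetical.

```python
def shift_to_tilt_um(dxy, tel_focal_length=24.5, tel_diameter=1.2):
    """Sketch: tilt coefficient [um] for an image-plane shift dxy [m].

    Assumes Z2 = 2 rho cos(theta) over a pupil of radius D/2, so a coefficient
    a gives a wavefront slope 4a/D and an image shift f * 4a/D.
    """
    return dxy * tel_diameter / (4.0 * tel_focal_length) * 1e6  # [m] -> [um]


pixel_scale = 12e-6           # 12 um pixels, as in the example above
dx = 0.5 * pixel_scale        # a 0.5-pixel shift expressed in [m]
zk1 = shift_to_tilt_um(dx)    # about 0.0735 um under this normalization
correction = -zk1             # sign flip used to match the shifted centroid
```

The pixel-to-meter step comes directly from the description above. Only the proportionality constant depends on the assumed Zernike normalization.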