wf_psf.data.training_preprocessing

Training Data Processing.

A module to load and preprocess training and test data.

Authors:

Jennifer Pollack <jennifer.pollack@cea.fr> and Tobias Liaudat <tobiasliaudat@gmail.com>

Functions

compute_ccd_misalignment(model_params, data)

Compute CCD misalignment.

compute_centroid_correction(model_params, data)

Compute centroid corrections using Zernike polynomials.

extract_star_data(data, train_key, test_key)

Extract specific star-related data from training and test datasets.

get_np_obs_positions(data)

Get observed positions in numpy from the provided dataset.

get_np_zernike_prior(data)

Get the Zernike prior from the provided dataset.

get_obs_positions(data)

Get observed positions from the provided dataset.

get_zernike_prior(model_params, data[, ...])

Get Zernike priors from the provided dataset.

Classes

DataHandler(dataset_type, data_params, ...)

Data Handler.

class wf_psf.data.training_preprocessing.DataHandler(dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True)[source]

Bases: object

Data Handler.

This class manages loading and processing of training and testing data for use during PSF model training and validation. It provides methods to access and preprocess the data.

Parameters:
  • dataset_type (str) – Type of dataset (“train” or “test”).

  • data_params (RecursiveNamespace) – Recursive Namespace object containing parameters for both ‘train’ and ‘test’ datasets.

  • simPSF (PSFSimulator) – Instance of the PSFSimulator class for simulating PSF models.

  • n_bins_lambda (int) – Number of wavelength bins for SED processing.

  • load_data (bool, optional) – If True, data is loaded and processed during initialization. If False, data loading is deferred until explicitly called. Default is True.

dataset_type

Type of dataset (“train” or “test”).

Type:

str

data_params

Parameters for the current dataset type.

Type:

RecursiveNamespace

dataset

Dictionary containing the loaded dataset, including positions and stars/noisy_stars.

Type:

dict or None

simPSF

Instance of the PSFSimulator class for simulating PSF models.

Type:

PSFSimulator

n_bins_lambda

Number of wavelength bins.

Type:

int

sed_data

TensorFlow tensor containing processed SED data for training/testing.

Type:

tf.Tensor or None

load_data_on_init

Flag controlling whether data is loaded during initialization.

Type:

bool

Methods

load_dataset()

Load dataset.

process_sed_data()

Process SED Data.

load_dataset()[source]

Load dataset.

Load the dataset based on the specified dataset type.

process_sed_data()[source]

Process SED Data.

A method to generate and process SED data.
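The effect of load_data can be illustrated with a small stand-in class. This is a sketch only: ToyDataHandler, its placeholder dataset, and the "SEDs" key are hypothetical and do not reproduce the internals of DataHandler. It mirrors only the documented pattern, in which loading either happens in the constructor or is deferred until load_dataset() and process_sed_data() are called explicitly.

```python
class ToyDataHandler:
    """Hypothetical stand-in mirroring the documented load_data behaviour."""

    def __init__(self, dataset_type, load_data=True):
        self.dataset_type = dataset_type
        self.load_data_on_init = load_data
        self.dataset = None   # stays None until load_dataset() runs
        self.sed_data = None  # stays None until process_sed_data() runs
        if self.load_data_on_init:
            self.load_dataset()
            self.process_sed_data()

    def load_dataset(self):
        # Placeholder for reading the data selected by dataset_type.
        self.dataset = {"positions": [[0.0, 0.0]], "SEDs": [[1.0, 2.0]]}

    def process_sed_data(self):
        # Placeholder for the SED processing step; needs a loaded dataset.
        if self.dataset is None:
            raise RuntimeError("call load_dataset() before process_sed_data()")
        self.sed_data = list(self.dataset["SEDs"])


eager = ToyDataHandler("train")                 # loaded during construction
lazy = ToyDataHandler("test", load_data=False)  # nothing loaded yet
lazy.load_dataset()                             # explicit, deferred loading
lazy.process_sed_data()
```

With load_data=False the dataset and sed_data attributes stay None until the two methods are called, which matches the "dict or None" and "tf.Tensor or None" types listed above.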

wf_psf.data.training_preprocessing.compute_ccd_misalignment(model_params, data)[source]

Compute CCD misalignment.

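Parameters:
  • model_params (RecursiveNamespace) – Object containing parameters for this PSF model class.

  • data (DataConfigHandler) – Object containing training and test datasets.

Returns: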

zernike_ccd_misalignment_array – Numpy array containing the Zernike contributions to model the CCD chip misalignments.

Return type:

np.ndarray

wf_psf.data.training_preprocessing.compute_centroid_correction(model_params, data, batch_size: int = 1) → ndarray[source]

Compute centroid corrections using Zernike polynomials.

This function calculates the Zernike contributions required to match the centroid of the WaveDiff PSF model to the observed star centroids, processing in batches.

Parameters:
  • model_params (RecursiveNamespace) – An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters.

  • data (DataConfigHandler) – An object containing training and test datasets, including observed PSFs and optional star masks.

  • batch_size (int, optional) – The batch size to use when processing the stars. Default is 1.

Returns:

zernike_centroid_array – A 2D NumPy array of shape (n_stars, 3), where n_stars is the number of observed stars. The array contains the computed Zernike contributions, with zero padding applied to the first column to ensure a consistent shape.

Return type:

np.ndarray
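The batching and the zero padding of the returned array can be sketched in NumPy as follows. This is an illustration under assumptions, not the library's implementation: toy_centroid_shifts is a hypothetical stand-in for the per-batch centroid computation, and the two non-zero columns hold whatever placeholder values it returns rather than real Zernike coefficients.

```python
import numpy as np


def toy_centroid_shifts(star_batch):
    """Hypothetical stand-in: two values per star, shape (batch, 2)."""
    # Intensity-weighted mean pixel coordinates, used only as placeholders.
    _, height, width = star_batch.shape
    ys, xs = np.mgrid[0:height, 0:width]
    total = star_batch.sum(axis=(1, 2))
    cx = (star_batch * xs).sum(axis=(1, 2)) / total
    cy = (star_batch * ys).sum(axis=(1, 2)) / total
    return np.stack([cx, cy], axis=1)


def batched_with_zero_first_column(stars, batch_size=1):
    """Process stars in batches, then zero-pad to shape (n_stars, 3)."""
    chunks = []
    for start in range(0, len(stars), batch_size):
        chunks.append(toy_centroid_shifts(stars[start:start + batch_size]))
    two_cols = np.concatenate(chunks, axis=0)  # shape (n_stars, 2)
    # Prepend one column of zeros; np.pad defaults to constant zero padding.
    return np.pad(two_cols, ((0, 0), (1, 0)))


stars = np.ones((5, 4, 4))  # five made-up 4x4 star stamps
result = batched_with_zero_first_column(stars, batch_size=2)
```

The batch size changes only how many stars are handled per step, so the output shape is (n_stars, 3) regardless of the value chosen.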

wf_psf.data.training_preprocessing.extract_star_data(data, train_key: str, test_key: str) → ndarray[source]

Extract specific star-related data from training and test datasets.

This function retrieves star-related data (e.g., star stamps or masks) from the training and test datasets using the provided keys, and concatenates the two results.

Parameters:
  • data (DataConfigHandler) – Object containing training and test datasets.

  • train_key (str) – The key to retrieve data from the training dataset (e.g., ‘noisy_stars’, ‘masks’).

  • test_key (str) – The key to retrieve data from the test dataset (e.g., ‘stars’, ‘masks’).

Returns:

A NumPy array containing the concatenated data for the given keys.

Return type:

np.ndarray

Raises:

KeyError – If the specified keys do not exist in the training or test datasets.

Notes

  • If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays.

  • Ensure that eager execution is enabled when calling this function.
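A minimal sketch of the behaviour described above. It assumes a layout that is not stated on this page: `data` exposes `training_data.dataset` and `test_data.dataset` as dictionaries, and the concatenation runs along axis 0. toy_extract is a hypothetical re-implementation for illustration only, and plain NumPy arrays stand in for TensorFlow tensors.

```python
from types import SimpleNamespace

import numpy as np


def _to_numpy(value):
    # Objects exposing .numpy() (such as eager tensors) are converted first.
    return value.numpy() if hasattr(value, "numpy") else np.asarray(value)


def toy_extract(data, train_key, test_key):
    """Hypothetical sketch: concatenate one entry from each dataset."""
    # A missing key raises KeyError from the dictionary lookup.
    train = data.training_data.dataset[train_key]
    test = data.test_data.dataset[test_key]
    return np.concatenate([_to_numpy(train), _to_numpy(test)], axis=0)


# Made-up container standing in for a DataConfigHandler instance.
data = SimpleNamespace(
    training_data=SimpleNamespace(dataset={"noisy_stars": np.zeros((3, 8, 8))}),
    test_data=SimpleNamespace(dataset={"stars": np.ones((2, 8, 8))}),
)
combined = toy_extract(data, "noisy_stars", "stars")
```

The training entries come first in the result, followed by the test entries, so a consumer can split the array again if it knows the number of training stars.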

wf_psf.data.training_preprocessing.get_np_obs_positions(data)[source]

Get observed positions in numpy from the provided dataset.

This function concatenates the positions of the stars from both the training and test datasets to obtain the observed positions.

Parameters:

data (DataConfigHandler) – Object containing training and test datasets.

Returns:

Numpy array containing the observed positions of the stars.

Return type:

np.ndarray

Notes

The observed positions are obtained by concatenating the positions of stars from both the training and test datasets along the 0th axis.
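The concatenation along the 0th axis can be pictured with made-up arrays. The values and the two-column (x, y) layout below are assumptions for illustration and are not taken from a real dataset.

```python
import numpy as np

# Hypothetical positions: three training stars and one test star.
train_positions = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
test_positions = np.array([[6.0, 7.0]])

# Stack the rows: training positions first, then test positions.
obs_positions = np.concatenate([train_positions, test_positions], axis=0)
```

The number of rows in the result is the sum of the training and test star counts, while the number of columns is unchanged.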

wf_psf.data.training_preprocessing.get_np_zernike_prior(data)[source]

Get the Zernike prior from the provided dataset.

This function concatenates the Zernike priors of the stars from both the training and test datasets to obtain the full prior.

Parameters:

data (DataConfigHandler) – Object containing training and test datasets.

Returns:

zernike_prior – Numpy array containing the full prior.

Return type:

np.ndarray

wf_psf.data.training_preprocessing.get_obs_positions(data)[source]

Get observed positions from the provided dataset.

Parameters:

data (DataConfigHandler) – Object containing training and test datasets.

Returns:

Tensor containing the observed positions of the stars.

Return type:

tf.Tensor

wf_psf.data.training_preprocessing.get_zernike_prior(model_params, data, batch_size: int = 16)[source]

Get Zernike priors from the provided dataset.

This function concatenates the Zernike priors from both the training and test datasets.

Parameters:
  • model_params (RecursiveNamespace) – Object containing parameters for this PSF model class.

  • data (DataConfigHandler) – Object containing training and test datasets.

  • batch_size (int, optional) – The batch size to use when processing the stars. Default is 16.

Returns:

Tensor containing the Zernike priors of the stars in the training and test datasets.

Return type:

tf.Tensor

Notes

The Zernike priors are obtained by concatenating the priors from the training and test datasets along the 0th axis.