wf_psf.data.training_preprocessing
Training Data Processing.
A module to load and preprocess the training and test (validation) data.
- Authors:
Jennifer Pollack <jennifer.pollack@cea.fr> and Tobias Liaudat <tobiasliaudat@gmail.com>
Functions
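compute_ccd_misalignment(model_params, data) | Compute CCD misalignment. |
compute_centroid_correction(model_params, data[, batch_size]) | Compute centroid corrections using Zernike polynomials. |
extract_star_data(data, train_key, test_key) | Extract specific star-related data from training and test datasets. |
get_np_obs_positions(data) | Get observed positions as a NumPy array from the provided dataset. |
get_np_zernike_prior(data) | Get the Zernike prior as a NumPy array from the provided dataset. |
get_obs_positions(data) | Get observed positions from the provided dataset. |
get_zernike_prior(model_params, data[, batch_size]) | Get Zernike priors from the provided dataset. |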
Classes
DataHandler(dataset_type, data_params, simPSF, n_bins_lambda[, load_data]) | Data Handler. |
- class wf_psf.data.training_preprocessing.DataHandler(dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True)
Bases: object
Data Handler.
This class manages loading and processing of training and testing data for use during PSF model training and validation. It provides methods to access and preprocess the data.
- Parameters:
dataset_type (str) – Type of dataset (“train” or “test”).
data_params (RecursiveNamespace) – Recursive Namespace object containing parameters for both ‘train’ and ‘test’ datasets.
simPSF (PSFSimulator) – Instance of the PSFSimulator class for simulating PSF models.
n_bins_lambda (int) – Number of wavelength bins for SED processing.
load_data (bool, optional) – If True, data is loaded and processed during initialization. If False, data loading is deferred until explicitly called. Default is True.
Attributes
- data_params
Parameters for the current dataset type.
- Type:
RecursiveNamespace
- dataset
Dictionary containing the loaded dataset, including positions and stars/noisy_stars.
- Type:
dict or None
- simPSF
Instance of the PSFSimulator class for simulating PSF models.
- Type:
PSFSimulator
- sed_data
TensorFlow tensor containing processed SED data for training/testing.
- Type:
tf.Tensor or None
Methods
Load dataset.
Process SED data.
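The load_data flag lets the constructor either load and process the data immediately or defer that work until it is explicitly requested. The sketch below illustrates only that deferral pattern. It is a hypothetical stand-in, not the DataHandler implementation: the class name MiniDataHandler, its load() method, the loader and process_seds callables (standing in for file I/O and PSFSimulator-based SED processing), and the "SEDs" dataset key are all assumptions made for illustration.

```python
from types import SimpleNamespace

import numpy as np


class MiniDataHandler:
    """Hypothetical stand-in illustrating deferred loading (not wf_psf's DataHandler)."""

    def __init__(self, dataset_type, data_params, loader, process_seds, load_data=True):
        # Keep only the parameters for the requested split ("train" or "test").
        self.dataset_type = dataset_type
        self.data_params = getattr(data_params, dataset_type)
        self.loader = loader              # hypothetical: replaces reading files from disk
        self.process_seds = process_seds  # hypothetical: replaces SED processing via simPSF
        self.dataset = None
        self.sed_data = None
        if load_data:
            self.load()

    def load(self):
        # Load the raw dataset dict, then derive the processed SED array from it.
        self.dataset = self.loader(self.data_params)
        self.sed_data = self.process_seds(self.dataset["SEDs"])


params = SimpleNamespace(
    train=SimpleNamespace(file="train.npy"),
    test=SimpleNamespace(file="test.npy"),
)


def fake_loader(p):
    # Hypothetical in-memory dataset with 4 stars and 10 SED bins of (value, wavelength).
    return {"positions": np.zeros((4, 2)), "SEDs": np.ones((4, 10, 2))}


handler = MiniDataHandler("train", params, fake_loader, np.asarray, load_data=False)
print(handler.dataset)  # None: nothing loaded yet
handler.load()
print(handler.sed_data.shape)  # (4, 10, 2)
```

Deferring the load is useful when a handler must exist before the data files are available, or when only one split is needed.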
- wf_psf.data.training_preprocessing.compute_ccd_misalignment(model_params, data)
Compute CCD misalignment.
- Parameters:
model_params (RecursiveNamespace) – Object containing parameters for this PSF model class.
data (DataConfigHandler) – Object containing training and test datasets.
- Returns:
zernike_ccd_misalignment_array – Numpy array containing the Zernike contributions to model the CCD chip misalignments.
- Return type:
np.ndarray
- wf_psf.data.training_preprocessing.compute_centroid_correction(model_params, data, batch_size: int = 1) -> ndarray
Compute centroid corrections using Zernike polynomials.
This function calculates the Zernike contributions required to match the centroid of the WaveDiff PSF model to the observed star centroids, processing in batches.
- Parameters:
model_params (RecursiveNamespace) – An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters.
data (DataConfigHandler) – An object containing training and test datasets, including observed PSFs and optional star masks.
batch_size (int, optional) – The batch size to use when processing the stars. Default is 1.
- Returns:
zernike_centroid_array – A 2D NumPy array of shape (n_stars, 3), where n_stars is the number of observed stars. The array contains the computed Zernike contributions, with zero padding applied to the first column to ensure a consistent shape.
- Return type:
np.ndarray
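The batched loop and the (n_stars, 3) output layout with a zero first column can be illustrated with a small NumPy sketch. It is not the WaveDiff algorithm. It assumes stamps stored as an (n_stars, ny, nx) array and uses a plain intensity-weighted centroid. It also converts pixel shifts to coefficients with a hypothetical constant scale pix_to_zernike. The real conversion, sign convention, masking, and centroid windowing depend on model_params and are not reproduced here. The function names are hypothetical.

```python
import numpy as np


def centroid_shifts(stamps):
    """Intensity-weighted centroid offsets (dx, dy) from each stamp's geometric center."""
    _, ny, nx = stamps.shape
    ys, xs = np.mgrid[0:ny, 0:nx]
    flux = stamps.sum(axis=(1, 2))
    cx = (stamps * xs).sum(axis=(1, 2)) / flux
    cy = (stamps * ys).sum(axis=(1, 2)) / flux
    return np.stack([cx - (nx - 1) / 2, cy - (ny - 1) / 2], axis=1)


def centroid_correction_sketch(stamps, pix_to_zernike=1.0, batch_size=1):
    """Hypothetical: scale centroid shifts into two coefficients per star, in batches."""
    n_stars = stamps.shape[0]
    out = np.zeros((n_stars, 3))  # column 0 is never written, so it stays zero (padding)
    for start in range(0, n_stars, batch_size):
        stop = start + batch_size
        out[start:stop, 1:] = pix_to_zernike * centroid_shifts(stamps[start:stop])
    return out


# Two 5x5 stamps: one centered point source, one shifted by one pixel in +x.
stamps = np.zeros((2, 5, 5))
stamps[0, 2, 2] = 1.0
stamps[1, 2, 3] = 1.0
print(centroid_correction_sketch(stamps, batch_size=1))  # rows [0, 0, 0] and [0, 1, 0]
```

Because each batch is written into its own slice of a preallocated array, the result does not depend on batch_size. Only the peak memory used per step changes.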
- wf_psf.data.training_preprocessing.extract_star_data(data, train_key: str, test_key: str) -> ndarray
Extract specific star-related data from training and test datasets.
This function retrieves star-related data (e.g., star stamps or masks) from the training and test datasets using the provided keys and concatenates them.
- Parameters:
data (DataConfigHandler) – Object containing training and test datasets.
train_key (str) – The key to retrieve data from the training dataset (e.g., ‘noisy_stars’, ‘masks’).
test_key (str) – The key to retrieve data from the test dataset (e.g., ‘stars’, ‘masks’).
- Returns:
A NumPy array containing the concatenated data for the given keys.
- Return type:
np.ndarray
- Raises:
KeyError – If the specified keys do not exist in the training or test datasets.
Notes
If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays.
Ensure that eager execution is enabled when calling this function.
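A minimal sketch of the retrieve-and-concatenate behavior follows. It assumes that data exposes training_data.dataset and test_data.dataset dictionaries, mimicked here with SimpleNamespace as a hypothetical stand-in for DataConfigHandler. The function name extract_star_data_sketch is also hypothetical. Inputs are converted with np.asarray; a real implementation handling TensorFlow tensors would convert them in eager mode (e.g. with .numpy()).

```python
from types import SimpleNamespace

import numpy as np


def extract_star_data_sketch(data, train_key, test_key):
    """Concatenate train[train_key] and test[test_key] along axis 0 (training rows first)."""
    train, test = data.training_data.dataset, data.test_data.dataset
    missing = [k for k, d in ((train_key, train), (test_key, test)) if k not in d]
    if missing:
        raise KeyError(f"Missing key(s): {missing}")
    return np.concatenate([np.asarray(train[train_key]), np.asarray(test[test_key])], axis=0)


# Hypothetical datasets: 3 noisy training stamps and 2 clean test stamps of 4x4 pixels.
data = SimpleNamespace(
    training_data=SimpleNamespace(dataset={"noisy_stars": np.ones((3, 4, 4))}),
    test_data=SimpleNamespace(dataset={"stars": np.zeros((2, 4, 4))}),
)
stars = extract_star_data_sketch(data, "noisy_stars", "stars")
print(stars.shape)  # (5, 4, 4)
```

Checking both keys before concatenating gives a single KeyError that names every missing key, rather than failing on the first lookup.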
- wf_psf.data.training_preprocessing.get_np_obs_positions(data)
Get observed positions as a NumPy array from the provided dataset.
This function concatenates the positions of the stars from both the training and test datasets to obtain the observed positions.
- Parameters:
data (DataConfigHandler) – Object containing training and test datasets.
- Returns:
NumPy array containing the observed positions of the stars.
- Return type:
np.ndarray
Notes
The observed positions are obtained by concatenating the positions of stars from both the training and test datasets along the 0th axis.
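The concatenation described in the Notes can be sketched as follows. This assumes, hypothetically, that each dataset dictionary stores an (n, 2) array of field positions under the "positions" key. The SimpleNamespace object stands in for DataConfigHandler, and get_np_obs_positions_sketch is an illustrative name rather than the module's function.

```python
from types import SimpleNamespace

import numpy as np


def get_np_obs_positions_sketch(data):
    """Stack training positions first, then test positions, along axis 0."""
    return np.concatenate(
        [np.asarray(data.training_data.dataset["positions"]),
         np.asarray(data.test_data.dataset["positions"])],
        axis=0,
    )


# Hypothetical positions: 2 training stars and 1 test star.
data = SimpleNamespace(
    training_data=SimpleNamespace(dataset={"positions": np.array([[0.0, 0.0], [10.0, 20.0]])}),
    test_data=SimpleNamespace(dataset={"positions": np.array([[5.0, 5.0]])}),
)
obs_pos = get_np_obs_positions_sketch(data)
print(obs_pos.shape)  # (3, 2)
```

The training rows always come first, so the first n_train rows of the result index the training stars.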
- wf_psf.data.training_preprocessing.get_np_zernike_prior(data)
Get the Zernike prior as a NumPy array from the provided dataset.
This function concatenates the Zernike priors of the stars from both the training and test datasets to obtain the full prior.
- Parameters:
data (DataConfigHandler) – Object containing training and test datasets.
- Returns:
zernike_prior – NumPy array containing the full prior.
- Return type:
np.ndarray
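Because training data precede test data in both concatenations, row i of the full prior describes the same star as row i of the observed positions. A sketch of that row alignment follows. It assumes, hypothetically, that each dataset dictionary stores the prior under a "zernike_prior" key with one row of n_zernikes coefficients per star. The helper concat_train_test and the SimpleNamespace stand-in for DataConfigHandler are illustrative only.

```python
from types import SimpleNamespace

import numpy as np


def concat_train_test(data, key):
    """Concatenate one dataset field, training rows first, then test rows."""
    return np.concatenate(
        [np.asarray(data.training_data.dataset[key]),
         np.asarray(data.test_data.dataset[key])],
        axis=0,
    )


n_zernikes = 4
data = SimpleNamespace(
    training_data=SimpleNamespace(dataset={
        "positions": np.array([[0.0, 0.0], [10.0, 20.0]]),
        "zernike_prior": np.arange(2 * n_zernikes, dtype=float).reshape(2, n_zernikes),
    }),
    test_data=SimpleNamespace(dataset={
        "positions": np.array([[5.0, 5.0]]),
        "zernike_prior": np.full((1, n_zernikes), -1.0),
    }),
)
positions = concat_train_test(data, "positions")
zernike_prior = concat_train_test(data, "zernike_prior")
# Row i of the prior belongs to the star at row i of the positions.
assert positions.shape[0] == zernike_prior.shape[0]
print(zernike_prior.shape)  # (3, 4)
```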
- wf_psf.data.training_preprocessing.get_obs_positions(data)
Get observed positions from the provided dataset.
- Parameters:
data (DataConfigHandler) – Object containing training and test datasets.
- Returns:
Tensor containing the observed positions of the stars.
- Return type:
tf.Tensor
- wf_psf.data.training_preprocessing.get_zernike_prior(model_params, data, batch_size: int = 16)
Get Zernike priors from the provided dataset.
This function concatenates the Zernike priors from both the training and test datasets.
- Parameters:
model_params (RecursiveNamespace) – Object containing parameters for this PSF model class.
data (DataConfigHandler) – Object containing training and test datasets.
batch_size (int, optional) – The batch size to use when processing the stars. Default is 16.
- Returns:
Tensor containing the Zernike priors of the stars.
- Return type:
tf.Tensor
Notes
The Zernike priors are obtained by concatenating the Zernike priors from both the training and test datasets along the 0th axis.