wf_psf.data.tensorflow_converter
TensorFlow dataset converter for PSF datasets.
This module provides the TensorFlowDatasetConverter class, which handles the conversion of PSF datasets (both dataclass-based and dict-based) into TensorFlow tensors suitable for training, evaluation, and inference. It includes methods for processing SEDs using a PSF simulator and converting various dataset formats into a consistent TensorFlow format.
Author: Jennifer Pollack <jennifer.pollack@cea.fr>
Classes
TensorFlowDatasetConverter – Convert structured dataset fields into TensorFlow-compatible tensors.
- class wf_psf.data.tensorflow_converter.TensorFlowDatasetConverter
Bases: object
Convert structured dataset fields into TensorFlow-compatible tensors.
This converter applies schema-driven preprocessing and validation to dataset fields used throughout the wf-psf pipeline. Conversion behavior is controlled by a DatasetSchema selected through the corresponding DatasetMode. Required and optional dataset fields are processed according to the active schema. Field-specific transformations are delegated to handler functions registered in the schema, while generic fields are converted directly to TensorFlow tensors.
Runtime dependencies required by specialized handlers (e.g. SED processing) are provided through a ConversionContext.
Notes
Required fields may raise exceptions or warnings depending on schema strictness.
Optional fields are processed only if present in the dataset.
Tensor conversion defaults to tf.float32 unless overridden by a field-specific handler.
This converter currently targets TensorFlow workflows and produces TensorFlow tensor outputs suitable for training and inference.
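The schema-driven dispatch described above can be illustrated with a short plain-Python sketch. It is hypothetical throughout: Schema, Context, convert_fields, sed_handler, and the field names positions and SEDs are illustrative stand-ins, not the wf_psf DatasetSchema, ConversionContext, or DatasetMode API. The generic converter is passed in as a function so the sketch runs without TensorFlow; in the converter itself that role is played by a tensor conversion that defaults to tf.float32.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class Context:
    """Hypothetical stand-in for ConversionContext: runtime dependencies for handlers."""
    sim_psf: Optional[Callable[[Any, int], Any]] = None
    n_bins_lambda: Optional[int] = None


@dataclass
class Schema:
    """Hypothetical stand-in for DatasetSchema: field names plus per-field handlers."""
    fields: Tuple[str, ...]
    handlers: Dict[str, Callable[[Any, Context], Any]] = field(default_factory=dict)


def convert_fields(dataset: dict, schema: Schema, context: Context,
                   default_convert: Callable[[Any], Any]) -> dict:
    """Send each present field to its registered handler, else to the generic converter."""
    out = {}
    for name in schema.fields:
        if name not in dataset:
            continue  # presence and strictness checks are omitted in this sketch
        handler = schema.handlers.get(name)
        if handler is not None:
            out[name] = handler(dataset[name], context)  # specialized path, receives the context
        else:
            out[name] = default_convert(dataset[name])  # generic path
    return out


def sed_handler(value: Any, context: Context) -> Any:
    """Hypothetical SED handler: fails if the context lacks the dependencies it needs."""
    if context.sim_psf is None or context.n_bins_lambda is None:
        raise ValueError("SED handler requires sim_psf and n_bins_lambda in the context")
    return context.sim_psf(value, context.n_bins_lambda)


# Usage with toy values: tuple() stands in for a real tensor conversion.
schema = Schema(fields=("positions", "SEDs"), handlers={"SEDs": sed_handler})
ctx = Context(sim_psf=lambda sed, n: [n] * len(sed), n_bins_lambda=4)
result = convert_fields({"positions": [1, 2], "SEDs": [0.1, 0.2]}, schema, ctx, tuple)
print(result)  # {'positions': (1, 2), 'SEDs': [4, 4]}
```

Keeping the runtime dependencies in one context object means only the handlers that need the simulator ever touch it, while generic fields stay on a single uniform conversion path.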
Methods
convert_dataset(dataset, simPSF, n_bins_lambda[, mode]) – Convert a dataset into TensorFlow-compatible tensors.
process_seds(sed_data, simPSF, n_bins_lambda) – Process SEDs using simPSF and convert to TensorFlow tensors.
- convert_dataset(dataset: DatasetContainer | dict, simPSF: PSFSimulator, n_bins_lambda: int, mode: DatasetMode = DatasetMode.TRAIN)
Convert a dataset into TensorFlow-compatible tensors.
Dataset conversion is performed according to the schema associated with the selected DatasetMode. Required and optional fields are processed independently, and field-specific preprocessing is delegated to registered schema handlers when applicable.
A ConversionContext is constructed internally and passed through the conversion pipeline to provide runtime dependencies required by specialized handlers.
- Parameters:
dataset (DatasetContainer or dict) – Input dataset containing raw arrays, metadata, and optional preprocessing fields.
simPSF (PSFSimulator) – PSF simulator used for SED preprocessing operations.
n_bins_lambda (int) – Number of wavelength bins used during SED discretization.
mode (DatasetMode, default=DatasetMode.TRAIN) – Dataset conversion mode defining the active schema and validation behavior.
- Returns:
Dictionary containing TensorFlow tensors keyed by canonical dataset field names.
- Return type:
dict
- Raises:
Raised if:
A required dataset field is missing under strict schema validation.
A required domain-specific context is unavailable for a registered field handler.
Notes
Required fields are validated according to schema strictness.
Optional fields are processed only if present.
Generic fields are converted using ensure_tensor.
Specialized preprocessing (e.g. SED conversion) is delegated to registered schema handlers.
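The required/optional split in these notes can be sketched as two small checks. This is a hypothetical sketch: check_required, present_optional, and the strict flag are illustrative names, and raising KeyError is an assumption, since this page does not name the exception type that convert_dataset raises.

```python
import warnings
from typing import Iterable, List


def check_required(dataset: dict, required: Iterable[str], strict: bool) -> List[str]:
    """Return missing required fields; raise under strict validation, otherwise warn."""
    missing = [name for name in required if name not in dataset]
    if missing:
        message = f"Missing required dataset fields: {missing}"
        if strict:
            raise KeyError(message)  # assumed exception type
        warnings.warn(message)
    return missing


def present_optional(dataset: dict, optional: Iterable[str]) -> List[str]:
    """Optional fields are processed only if present; absent ones are skipped silently."""
    return [name for name in optional if name in dataset]


# Usage with toy field names.
data = {"positions": [0.0], "SEDs": [1.0]}
print(check_required(data, ["positions", "SEDs"], strict=True))  # []
print(present_optional(data, ["SEDs", "masks"]))  # ['SEDs']
```

Separating the two checks mirrors the independent processing of required and optional fields: a missing optional field never triggers a warning, while a missing required one does, escalating to an error only under strict validation.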
- static process_seds(sed_data, simPSF, n_bins_lambda)
Process SEDs using simPSF and convert to TensorFlow tensors.
This is a core operation that must be performed on all SED data before use in training or inference. It converts raw SED arrays into wavelength-binned TensorFlow tensors.
- Parameters:
sed_data (array_like) – Array of SEDs, shape (N, n_wavelengths) or similar
simPSF (PSFSimulator) – PSF simulator used for SED processing.
n_bins_lambda (int) – Number of wavelength bins for SED processing.
- Returns:
Processed SED tensor, shape (N, n_bins_lambda, n_components)
- Return type:
tf.Tensor
- Raises:
ValueError – If sed_data is None
Notes
Uses tf.float64 internally for precision during generation
Returns tf.float32 for training efficiency
Transposes to shape (N, n_bins_lambda, n_components)
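The dtype and axis handling listed in these notes can be sketched with NumPy in place of TensorFlow (np.asarray, np.transpose, and astype play the roles of tf.convert_to_tensor, tf.transpose, and tf.cast). This sketch rests on assumptions: pack_sed is a hypothetical stand-in for the simPSF call, assumed to return an array of shape (n_components, n_bins_lambda) for each SED, and its interpolation-based binning is illustrative only. The actual simulator method and its binning scheme are not documented on this page.

```python
import numpy as np


def pack_sed(sed, n_bins_lambda, n_components=3):
    """Hypothetical simulator stand-in: returns shape (n_components, n_bins_lambda)."""
    sed = np.asarray(sed, dtype=np.float64)
    # Linearly resample the SED onto n_bins_lambda points (illustrative binning only).
    grid = np.linspace(0, sed.size - 1, n_bins_lambda)
    binned = np.interp(grid, np.arange(sed.size), sed)
    return np.stack([binned] * n_components)


def process_seds_sketch(sed_data, pack, n_bins_lambda):
    """Mirror the documented dtype and shape handling of process_seds."""
    if sed_data is None:
        raise ValueError("sed_data must not be None")
    # Stack the per-SED outputs in float64: shape (N, n_components, n_bins_lambda).
    packed = np.asarray([pack(sed, n_bins_lambda) for sed in sed_data], dtype=np.float64)
    # Swap the last two axes: shape (N, n_bins_lambda, n_components).
    transposed = np.transpose(packed, (0, 2, 1))
    # Cast down to float32 for training.
    return transposed.astype(np.float32)


out = process_seds_sketch(np.ones((5, 10)), pack_sed, n_bins_lambda=8)
print(out.shape, out.dtype)  # (5, 8, 3) float32
```

Doing the stacking and reordering in float64 and casting only at the end keeps the precision loss to a single final step, which matches the float64-then-float32 order given in the notes.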