# Source code for wf_psf.data.tensorflow_converter

"""TensorFlow dataset converter for PSF datasets.

This module provides the `TensorFlowDatasetConverter` class, which handles the conversion of PSF datasets (both dataclass-based and dict-based) into TensorFlow tensors suitable for training, evaluation, and inference. It includes methods for processing SEDs using a PSF simulator and converting various dataset formats into a consistent TensorFlow format.

Author: Jennifer Pollack <jennifer.pollack@cea.fr>
"""

import tensorflow as tf
from typing import Union
from wf_psf.data.schemas import DatasetMode, SCHEMAS
from wf_psf.data.data_utils import DatasetContainer, ConversionContext, SEDContext
from wf_psf.psf_models.psf_models import PSFSimulator
from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor
from wf_psf.utils.utils import generate_SED_elems_in_tensorflow
import logging

logger = logging.getLogger(__name__)


class TensorFlowDatasetConverter:
    """
    Convert structured dataset fields into TensorFlow-compatible tensors.

    This converter applies schema-driven preprocessing and validation to
    dataset fields used throughout the wf-psf pipeline. Conversion behavior
    is controlled by a :class:`DatasetSchema` selected through the
    corresponding :class:`DatasetMode`.

    Required and optional dataset fields are processed according to the
    active schema. Field-specific transformations are delegated to handler
    functions registered in the schema, while generic fields are converted
    directly to TensorFlow tensors. Runtime dependencies required by
    specialized handlers (e.g. SED processing) are provided through a
    :class:`ConversionContext`.

    Notes
    -----
    - Required fields may raise exceptions or warnings depending on schema
      strictness.
    - Optional fields are processed only if present in the dataset.
    - Tensor conversion defaults to ``tf.float32`` unless overridden by a
      field-specific handler.
    - This converter currently targets TensorFlow workflows and produces
      TensorFlow tensor outputs suitable for training and inference.
    """

    # Sentinel distinguishing "field absent" from a field that is present
    # but holds ``None``; created once instead of on every field lookup.
    _MISSING = object()

    def _process_field(
        self,
        dataset,
        result,
        key,
        schema,
        context,
        required,
    ):
        """
        Process a single dataset field according to schema rules.

        This method applies validation, handler dispatch, and TensorFlow
        tensor conversion for an individual dataset field. Specialized
        preprocessing is delegated to schema-registered handlers when
        available; otherwise, values are converted directly to tensors.

        Parameters
        ----------
        dataset : DatasetContainer or dict
            Input dataset containing raw field values. Must expose a
            ``get(key, default)`` method.
        result : dict
            Mutable output dictionary storing converted dataset fields.
            Updated in place.
        key : str
            Canonical dataset field name to process.
        schema : DatasetSchema
            Active dataset schema defining required fields, optional
            fields, strictness behavior, and field handlers.
        context : ConversionContext
            Runtime conversion context containing optional domain-specific
            processing dependencies. The dependency for a handled field is
            looked up as the attribute named ``key``.
        required : bool
            Whether the field is required under the active schema.

        Raises
        ------
        ValueError
            Raised if:

            - A required field is missing while schema strictness is
              enabled.
            - A handler requires a domain-specific context that is absent.

        Notes
        -----
        - Missing optional fields are silently ignored.
        - Missing required fields generate warnings when schema strictness
          is disabled.
        - Field handlers are resolved dynamically from the active schema
          and are called as ``handler(converter, value, domain_context)``.
        """
        value = dataset.get(key, self._MISSING)

        if value is self._MISSING:
            if required:
                # Same wording for the strict (raise) and lenient (warn)
                # paths, so build the message once.
                message = (
                    f"Dataset field '{key}' required for "
                    f"{schema.id} is missing."
                )
                if schema.strict:
                    raise ValueError(message)
                logger.warning(message)
            return

        handler = schema.handlers.get(key)
        if handler is None:
            # Generic field: plain tensor conversion.
            result[key] = ensure_tensor(value, dtype=tf.float32)
            return

        # Handled field: the handler needs its domain-specific context,
        # which is stored on the conversion context under the field name.
        domain_ctx = getattr(context, key, None)
        if domain_ctx is None:
            raise ValueError(f"Missing context for domain '{key}'")
        result[key] = handler(self, value, domain_ctx)

    def convert_dataset(
        self,
        dataset: Union[DatasetContainer, dict],
        simPSF: PSFSimulator,
        n_bins_lambda: int,
        mode: DatasetMode = DatasetMode.TRAIN,
    ) -> dict:
        """
        Convert a dataset into TensorFlow-compatible tensors.

        Dataset conversion is performed according to the schema associated
        with the selected :class:`DatasetMode`. Required and optional
        fields are processed independently, and field-specific
        preprocessing is delegated to registered schema handlers when
        applicable.

        A :class:`ConversionContext` is constructed internally and passed
        through the conversion pipeline to provide runtime dependencies
        required by specialized handlers.

        Parameters
        ----------
        dataset : DatasetContainer or dict
            Input dataset containing raw arrays, metadata, and optional
            preprocessing fields.
        simPSF : PSFSimulator
            PSF simulator used for SED preprocessing operations.
        n_bins_lambda : int
            Number of wavelength bins used during SED discretization.
        mode : DatasetMode, default=DatasetMode.TRAIN
            Dataset conversion mode defining the active schema and
            validation behaviour.

        Returns
        -------
        dict
            Dictionary keyed by canonical dataset field names. Fields
            listed in the schema are converted to TensorFlow tensors.

        Raises
        ------
        ValueError
            Raised if:

            - A required dataset field is missing under strict schema
              validation.
            - A required domain-specific context is unavailable for a
              registered field handler.

        Notes
        -----
        - Required fields are validated according to schema strictness.
        - Optional fields are processed only if present.
        - Generic fields are converted using ``ensure_tensor``.
        - Specialized preprocessing (e.g. SED conversion) is delegated to
          registered schema handlers.
        - The output starts as ``dict(dataset)``, so any input field that
          is not named by the schema is carried over unconverted, and the
          input dataset itself is not modified by this method.
          NOTE(review): this relies on ``dict()`` accepting a
          ``DatasetContainer`` -- confirm against its definition.
        """
        schema = SCHEMAS[mode]
        req_keys = schema.required_keys
        opt_keys = schema.optional_keys

        logger.info(
            f"Using dataset schema '{schema.id}' "
            f"(required_fields={len(req_keys)}, "
            f"optional_fields={len(opt_keys)})"
        )

        result = dict(dataset)
        context = ConversionContext(
            seds=SEDContext(simPSF=simPSF, n_bins_lambda=n_bins_lambda)
        )

        # Required keys are processed before optional ones so that strict
        # validation errors surface first.
        for keys, required in ((req_keys, True), (opt_keys, False)):
            for k in keys:
                self._process_field(
                    dataset=dataset,
                    result=result,
                    key=k,
                    schema=schema,
                    context=context,
                    required=required,
                )

        return result

    @staticmethod
    def process_seds(sed_data, simPSF, n_bins_lambda):
        """
        Process SEDs using simPSF and convert to TensorFlow tensors.

        This is a core operation that must be performed on all SED data
        before use in training or inference. Converts raw SED arrays into
        wavelength-binned TensorFlow tensors.

        Parameters
        ----------
        sed_data : array_like
            Array of SEDs, shape (N, n_wavelengths) or similar. Iterated
            one SED at a time.
        simPSF : PSFSimulator
            PSF simulator used for SED processing.
        n_bins_lambda : int
            Number of wavelength bins for SED processing.

        Returns
        -------
        tf.Tensor
            Processed SED tensor, shape (N, n_bins_lambda, n_components).

        Raises
        ------
        ValueError
            If sed_data is None.

        Notes
        -----
        - Uses tf.float64 internally for precision during generation.
        - Returns tf.float32 for training efficiency.
        - Transposes the last two axes to obtain shape
          (N, n_bins_lambda, n_components).
        """
        # Fail with the documented error instead of an opaque TypeError
        # from iterating over None.
        if sed_data is None:
            raise ValueError("sed_data must not be None.")

        processed = [
            generate_SED_elems_in_tensorflow(
                sed, simPSF, n_bins=n_bins_lambda, tf_dtype=tf.float64
            )
            for sed in sed_data
        ]
        sed_tensor = ensure_tensor(processed, dtype=tf.float32)
        return tf.transpose(sed_tensor, perm=[0, 2, 1])