# Source code for wf_psf.data.schemas

"""
Dataset schema definitions for canonical dataset handling.

This module defines dataset validation and conversion schemas used across
training, evaluation, and inference workflows. Schemas specify which
canonical dataset fields are required, which are optional, and whether
missing required fields should raise an exception.

These schemas provide a centralized contract between dataset adapters,
preprocessing pipelines, and TensorFlow conversion utilities.

Canonical dataset field names are defined in ``constants.py`` and represent
the normalized internal dataset interface used throughout the library,
independent of external dataset naming conventions.

Author(s): Jennifer Pollack <jennifer.pollack@cea.fr>
"""

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Any
from wf_psf.data.constants import (
    CANONICAL_DATASET_KEYS,
    OPTIONAL_KEYS as CONST_OPTIONAL_KEYS,
    SED_DOMAIN,
)
from wf_psf.data.handlers import process_seds_handler


class DatasetMode(Enum):
    """Enumeration of supported dataset operation modes.

    These modes define the expected dataset contract for different stages
    of the wf-psf workflow. Each member is used as a key of the ``SCHEMAS``
    registry defined in this module.

    Attributes
    ----------
    TRAIN
        Dataset schema used during model training.
    EVALUATION
        Dataset schema used during evaluation.
    INFERENCE
        Dataset schema used during inference or prediction.
    """

    # Values come from ``auto()``; refer to members by name or identity
    # rather than relying on the generated integers.
    TRAIN = auto()
    EVALUATION = auto()
    INFERENCE = auto()
@dataclass(frozen=True)
class DatasetSchema:
    """Definition of a canonical dataset schema.

    A dataset schema specifies which canonical dataset fields are required,
    which fields are optional, and whether missing required fields should
    raise an exception during validation or conversion. The schema only
    records this policy; it defines no validation logic of its own.

    Parameters
    ----------
    id : str
        Schema identifier (e.g. "train", "evaluation", "inference").
    required_keys : tuple[str, ...]
        Canonical dataset fields that must be present.
    optional_keys : tuple[str, ...]
        Canonical dataset fields that may be present and will be
        processed if available.
    strict : bool, optional
        If ``True``, missing required fields raise an exception.
        If ``False``, missing required fields generate warnings and
        are skipped. Default is ``True``.
    handlers : dict[str, Callable[..., Any]], optional
        Handlers for specific dataset fields (e.g. seds), keyed by field
        name. Defaults to a new empty dict for each instance.

    Notes
    -----
    Instances are frozen, so attributes cannot be reassigned after
    construction. ``handlers`` takes part in equality comparison but is
    excluded from hashing: a dict is unhashable, and including it would make
    ``hash(schema)`` raise ``TypeError``. Hashing still requires the
    remaining fields (e.g. ``required_keys``) to be hashable.
    """

    id: str
    required_keys: tuple[str, ...]
    optional_keys: tuple[str, ...]
    strict: bool = True
    # hash=False: keep the (unhashable) dict out of the generated __hash__
    # while still comparing it in __eq__.
    handlers: dict[str, Callable[..., Any]] = field(
        default_factory=dict, hash=False
    )
# NOTE(review): TRAIN_SCHEMA keys its SED handler with the ``SED_DOMAIN``
# constant while EVALUATION_SCHEMA and INFERENCE_SCHEMA use the literal
# "seds". Presumably ``SED_DOMAIN == "seds"`` -- confirm against
# ``wf_psf.data.constants`` and settle on one spelling.
TRAIN_SCHEMA = DatasetSchema(
    id="train",
    required_keys=CANONICAL_DATASET_KEYS,
    optional_keys=CONST_OPTIONAL_KEYS,
    strict=True,
    handlers={SED_DOMAIN: process_seds_handler},
)
"""
Dataset schema used during model training.

All canonical dataset fields are required during training.
Missing required fields raise exceptions.
"""

EVALUATION_SCHEMA = DatasetSchema(
    id="evaluation",
    required_keys=CANONICAL_DATASET_KEYS,
    optional_keys=CONST_OPTIONAL_KEYS,
    strict=True,
    handlers={"seds": process_seds_handler},
)
"""
Dataset schema used during model evaluation.

All canonical dataset fields are required during evaluation.
Missing required fields raise exceptions.
"""

INFERENCE_SCHEMA = DatasetSchema(
    id="inference",
    required_keys=("seds", "positions"),
    optional_keys=CONST_OPTIONAL_KEYS,
    strict=False,
    handlers={"seds": process_seds_handler},
)
"""
Dataset schema used during inference.

Inference requires only the minimal subset of canonical fields needed
for prediction. Missing required fields generate warnings rather than
raising exceptions.
"""

SCHEMAS = {
    DatasetMode.TRAIN: TRAIN_SCHEMA,
    DatasetMode.EVALUATION: EVALUATION_SCHEMA,
    DatasetMode.INFERENCE: INFERENCE_SCHEMA,
}
"""
Registry mapping dataset operation modes to dataset schemas.

This dictionary provides centralized access to workflow-specific dataset
contracts used throughout the preprocessing and conversion pipeline.
"""