# Source code for wf_psf.training.training_config_handler
"""Training Config Handler.
A module which provides a class to manage the parameters of the training config file.
:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import os
import tensorflow as tf
from wf_psf.data.data_adapter import StructureState, RepresentationState, DataAdapter
from wf_psf.data.data_config_handler import DataConfigHandler
from wf_psf.data.factory import DataAdapterFactory
from wf_psf.data.schemas import DatasetMode
from wf_psf.metrics.metrics_config_handler import MetricsConfigHandler
from wf_psf.psf_models import psf_models
from wf_psf.utils.configs_handler import ConfigHandler, register_configclass
from wf_psf.utils.read_config import read_conf
from wf_psf.training import train
from wf_psf.training.training_data_adapter import TrainingDataAdapter
import logging
logger = logging.getLogger(__name__)
[docs]
@register_configclass
class TrainingConfigHandler(ConfigHandler):
    """Handler for the training configuration.

    Reads the ``training`` section of the training configuration file and
    sets up what a training run needs: the data configuration handler,
    the PSF simulator and the checkpoint / optimizer / PSF-model output
    directories obtained from the file handler.

    Parameters
    ----------
    training_conf : str
        Path of the training configuration file.
    file_handler : object
        An instance of the FileIOHandler class.

    Attributes
    ----------
    ids : tuple
        A tuple containing a string id for the Configuration Class.
    """

    ids = ("training_conf",)

    def __init__(self, training_conf, file_handler):
        self.training_conf = read_conf(training_conf).training
        self.file_handler = file_handler

        # The data config file name is given relative to the config directory.
        data_config_path = os.path.join(
            file_handler.config_path, self.training_conf.data_config
        )
        self.data_params = DataConfigHandler(data_config_path)

        model_params = self.training_conf.model_params
        self.n_bins_lambda = model_params.n_bins_lambda
        self.simPSF = psf_models.simPSF(model_params)

        # Keep a copy of the data config alongside the run outputs.
        file_handler.copy_conffile_to_output_dir(self.training_conf.data_config)

        # All three output locations are derived from the run output directory.
        self.checkpoint_dir = file_handler.get_checkpoint_dir(
            file_handler._run_output_dir
        )
        self.optimizer_dir = file_handler.get_optimizer_dir(
            file_handler._run_output_dir
        )
        self.psf_model_dir = file_handler.get_psf_model_dir(
            file_handler._run_output_dir
        )

    def run(self):
        """Run training, then metrics evaluation when configured.

        Prepares the training inputs, trains the PSF model and, if the
        training configuration names a metrics configuration file,
        copies that file to the output directory and runs the metrics
        handler on it.
        """
        conf = self.training_conf
        hparams = conf.training_hparams

        training_adapter, psf_model = prepare_training_inputs(
            self.data_params,
            self.simPSF,
            self.n_bins_lambda,
            hparams.loss,
            conf.model_params,
            hparams,
        )

        train.train(
            conf,
            training_adapter,
            psf_model,
            self.checkpoint_dir,
            self.optimizer_dir,
            self.psf_model_dir,
        )

        # Metrics evaluation is optional: nothing more to do without a config.
        metrics_config = conf.metrics_config
        if metrics_config is None:
            return

        self.file_handler.copy_conffile_to_output_dir(metrics_config)
        metrics_config_path = os.path.join(
            self.file_handler.config_path, metrics_config
        )
        MetricsConfigHandler(metrics_config_path, self.file_handler, conf).run()
[docs]
def prepare_training_inputs(
    data_params,
    simPSF,
    n_bins_lambda,
    loss,
    model_params,
    training_hparams,
) -> tuple[TrainingDataAdapter, tf.keras.Model]:
    """Prepare the data adapter and PSF model used for training.

    The steps run in a fixed order. The dataset is first brought into
    its complete (joined) form, because the PSF model is initialised
    from the complete data (certain models require the full dataset).
    Only afterwards is the dataset split and converted to tensors.

    Parameters
    ----------
    data_params : RecursiveNamespace or SHEPSFDataset
        Data configuration parameters or a pre-loaded in-memory dataset.
    simPSF : PSFSimulator
        PSF simulator instance used for SED encoding during conversion.
    n_bins_lambda : int
        Number of wavelength bins for SED discretisation.
    loss : str
        Loss function identifier, determines whether masks are packed
        with target images in the training adapter.
    model_params : RecursiveNamespace
        PSF model configuration parameters.
    training_hparams : RecursiveNamespace
        Training hyperparameters passed to PSF model initialisation.

    Returns
    -------
    tuple[TrainingDataAdapter, PSFModel]
        The training adapter wrapping the prepared data and the
        initialised PSF model, ready to be passed to the training loop.

    Raises
    ------
    RuntimeError
        If, after preparation, the data is not split or not in
        TensorFlow representation.
    """
    data_adapter = DataAdapterFactory.build(data_params)

    # Step 1: make sure the complete dataset exists before building the model.
    if data_adapter.structure_state == StructureState.SPLIT:
        logger.info("Joining split datasets...")
        data_adapter.join_data()

    # Step 2: the PSF model is initialised from the complete data.
    logger.info(f"Initialising PSF model {model_params.model_name}...")
    model = psf_models.get_psf_model(
        model_params, training_hparams, data_adapter.complete_data
    )

    # Step 3: go back to split form for training.
    if data_adapter.structure_state == StructureState.COMPLETE:
        logger.info("Generating split datasets...")
        data_adapter.split_data()

    # Step 4: convert NumPy-backed data to tensors in training mode.
    if data_adapter.representation_state == RepresentationState.NUMPY:
        logger.info("Converting dataset to tensors...")
        data_adapter.convert_to_tensorflow(
            simPSF, n_bins_lambda, mode=DatasetMode.TRAIN
        )

    # Step 5: refuse to continue unless the data is split and in tensor form.
    logger.info("Validating data preparation state...")
    _assert_data_prepared(data_adapter)

    logger.info("Data preparation complete. Freezing dataset snapshot...")
    training_adapter = TrainingDataAdapter(data_adapter, loss)
    return training_adapter, model
def _assert_data_prepared(adapter: DataAdapter):
    """Check that the adapter holds split data in TensorFlow representation.

    Parameters
    ----------
    adapter : DataAdapter
        The data adapter whose state is checked.

    Raises
    ------
    RuntimeError
        If the structure state is not ``StructureState.SPLIT``, or if the
        representation state is not ``RepresentationState.TENSORFLOW``.
        The structure state is checked first.
    """
    required_structure = StructureState.SPLIT
    if adapter.structure_state != required_structure:
        message = f"Expected SPLIT data, got {adapter.structure_state}"
        raise RuntimeError(message)

    required_representation = RepresentationState.TENSORFLOW
    if adapter.representation_state != required_representation:
        message = f"Expected TensorFlow data, got {adapter.representation_state}"
        raise RuntimeError(message)