wf_psf.training.train

Train.

A module that defines the classes and methods used to manage training of the PSF model.

Authors:

Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>

Functions

filepath_chkp_callback(checkpoint_dir, ...)

Generate a file path for a checkpoint callback.

get_gpu_info()

Get GPU Information.

get_loss_metrics_monitor_and_outputs(...)

Generate factory for loss, metrics, monitor, and outputs.

setup_training()

Set up Training.

train(training_params, data_conf, ...)

Train the PSF model over one or more parametric and non-parametric training cycles.

Classes

TrainingParamsHandler(training_params)

Training Parameters Handler.

class wf_psf.training.train.TrainingParamsHandler(training_params)[source]

Bases: object

Training Parameters Handler.

A class that provides access to the training parameters.

Parameters:

training_params (Recursive Namespace object) – Recursive Namespace object containing training input parameters

Attributes:
id_name

ID Name.

learning_rate_non_params

Non-parametric Model Learning Rate.

learning_rate_params

Parametric Model Learning Rate.

model_name

PSF Model Name.

model_params

PSF Model Params.

multi_cycle_params

Training Multi Cycle Parameters.

n_epochs_non_params

Number of Epochs for Non-parametric PSF model.

n_epochs_params

Number of Epochs for Parametric PSF model.

total_cycles

Total Number of Cycles.

training_hparams

Training Hyperparameters.

property id_name

ID Name.

Set unique ID name.

Returns:

id_name – A unique ID.

Return type:

str

property learning_rate_non_params

Non-parametric Model Learning Rate.

Set learning rate for non-parametric PSF model.

Returns:

learning_rate_non_params – List containing learning rate for non-parametric PSF model

Return type:

list

property learning_rate_params

Parametric Model Learning Rate.

Set learning rate for parametric PSF model.

Returns:

learning_rate_params – List containing learning rate for parametric PSF model

Return type:

list

property model_name

PSF Model Name.

Set model_name.

Returns:

model_name – Name of PSF model

Return type:

str

property model_params

PSF Model Params.

Set PSF model training parameters.

Returns:

model_params – Recursive Namespace object storing PSF model parameters

Return type:

Recursive Namespace object

property multi_cycle_params

Training Multi Cycle Parameters.

Set training multi-cycle parameters.

Returns:

multi_cycle_params – Recursive Namespace object storing training multi-cycle parameters

Return type:

Recursive Namespace object

property n_epochs_non_params

Number of Epochs for Non-parametric PSF model.

Set the number of epochs for training non-parametric PSF model.

Returns:

n_epochs_non_params – List of number of epochs for training non-parametric PSF model.

Return type:

list

property n_epochs_params

Number of Epochs for Parametric PSF model.

Set the number of epochs for training parametric PSF model.

Returns:

n_epochs_params – List of number of epochs for training parametric PSF model.

Return type:

list

property total_cycles

Total Number of Cycles.

Set total number of cycles for training.

Returns:

total_cycles – Total number of cycles for training

Return type:

int

property training_hparams

Training Hyperparameters.

Set training hyperparameters.

Returns:

training_hparams – Recursive Namespace object storing training hyperparameters

Return type:

Recursive Namespace object
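
The property layout above can be pictured with a small stand-in. The sketch below is not the wf_psf implementation: the class name TrainingParamsHandlerSketch, the nesting of fields (for example multi_cycle_params living under training_hparams), and all configuration values are assumptions made for illustration, and types.SimpleNamespace stands in for the Recursive Namespace object.

```python
from types import SimpleNamespace


class TrainingParamsHandlerSketch:
    """Illustrative stand-in for a parameter handler (hypothetical)."""

    def __init__(self, training_params):
        self.training_params = training_params

    @property
    def id_name(self):
        return self.training_params.id_name

    @property
    def model_params(self):
        return self.training_params.model_params

    @property
    def model_name(self):
        # Assumption: the model name is stored inside model_params.
        return self.model_params.model_name

    @property
    def training_hparams(self):
        return self.training_params.training_hparams

    @property
    def multi_cycle_params(self):
        # Assumption: multi-cycle settings are nested in the hyperparameters.
        return self.training_hparams.multi_cycle_params

    @property
    def total_cycles(self):
        return self.multi_cycle_params.total_cycles

    @property
    def n_epochs_params(self):
        return self.multi_cycle_params.n_epochs_params

    @property
    def learning_rate_params(self):
        return self.multi_cycle_params.learning_rate_params


# Hypothetical configuration values, for illustration only.
training_params = SimpleNamespace(
    id_name="_demo_run",
    model_params=SimpleNamespace(model_name="demo_model"),
    training_hparams=SimpleNamespace(
        multi_cycle_params=SimpleNamespace(
            total_cycles=2,
            n_epochs_params=[20, 20],
            learning_rate_params=[1e-2, 1e-2],
        )
    ),
)

handler = TrainingParamsHandlerSketch(training_params)
print(handler.model_name, handler.total_cycles, handler.n_epochs_params)
```

Exposing the nested configuration through read-only properties keeps calling code free of long attribute chains; the list-valued properties hold one entry per training cycle.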

wf_psf.training.train.filepath_chkp_callback(checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int) → str[source]

Generate a file path for a checkpoint callback.

Parameters:
  • checkpoint_dir (str) – The directory where the checkpoint will be saved.

  • model_name (str) – The name of the model.

  • id_name (str) – The unique identifier for the model instance.

  • current_cycle (int) – The current cycle number.

Returns:

A string representing the full file path for the checkpoint callback.

Return type:

str
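
As a rough illustration of how the four arguments might combine into one path, consider the sketch below. The function name sketch_checkpoint_path and the file-name pattern are hypothetical; the actual pattern used by filepath_chkp_callback is not stated in this reference and may differ.

```python
import os


def sketch_checkpoint_path(checkpoint_dir, model_name, id_name, current_cycle):
    """Build a per-cycle checkpoint path (hypothetical naming pattern)."""
    # Assumption: model name, run ID, and cycle number are all encoded in
    # the file name so that checkpoints from different cycles do not collide.
    filename = f"checkpoint_{model_name}{id_name}_cycle{current_cycle}"
    return os.path.join(checkpoint_dir, filename)


path = sketch_checkpoint_path("checkpoints", "demo_model", "_demo_run", 1)
print(path)
```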

wf_psf.training.train.get_gpu_info()[source]

Get GPU Information.

A function that returns the name of the GPU device.

Returns:

device_name – Name of GPU device

Return type:

str
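
One way to obtain a GPU device name with TensorFlow is tf.test.gpu_device_name(), which returns an empty string when no GPU is available. Whether get_gpu_info uses this exact call is an assumption; the helper name sketch_gpu_device_name and the fallback for a missing TensorFlow installation are added here for illustration.

```python
def sketch_gpu_device_name():
    """Return the GPU device name, or an empty string if none is found."""
    try:
        import tensorflow as tf
    except ImportError:
        # Assumption for this sketch: treat a missing TensorFlow install
        # the same as "no GPU available".
        return ""
    return tf.test.gpu_device_name()


device_name = sketch_gpu_device_name()
print(device_name or "no GPU detected")
```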

wf_psf.training.train.get_loss_metrics_monitor_and_outputs(training_handler, data_conf)[source]

Generate factory for loss, metrics, monitor, and outputs.

A function to generate loss, metrics, monitor, and outputs for training.

Parameters:
  • training_handler (TrainingParamsHandler) – TrainingParamsHandler object containing training parameters

  • data_conf (object) – Data configuration object containing training and test data

Returns:

  • loss (tf.keras.losses.Loss) – Loss function to be used for training

  • param_metrics (list) – List of metrics for the parametric model

  • non_param_metrics (list) – List of metrics for the non-parametric model

  • monitor (str) – Metric to monitor for saving the model

  • outputs (tf.Tensor) – Tensor containing the outputs for training

  • output_val (tf.Tensor) – Tensor containing the outputs for validation

wf_psf.training.train.setup_training()[source]

Set up Training.

A function to set up training.

wf_psf.training.train.train(training_params, data_conf, checkpoint_dir, optimizer_dir, psf_model_dir)[source]

Train the PSF model over one or more parametric and non-parametric training cycles.

This function manages multi-cycle training of a parametric + non-parametric PSF model, including initialization, loss/metric configuration, optimizer setup, model checkpointing, and optional projection or resetting of non-parametric features. Each cycle can include both parametric and non-parametric training stages, and training history is saved for each.

Parameters:
  • training_params (RecursiveNamespace) – Contains all training configuration parameters, including:

      - learning rates per cycle
      - number of epochs per component per cycle
      - model type and training behavior flags
      - multi-cycle definitions and callbacks

  • data_conf (object) – Contains training and validation datasets via attributes:

      - data_conf.training_data: TrainingDataHandler instance with SEDs and positions
      - data_conf.test_data: TestDataHandler instance with validation SEDs and positions

  • checkpoint_dir (str) – Directory where model checkpoints will be saved during training.

  • optimizer_dir (str) – Directory where the optimizer history (as a NumPy .npy file) will be stored.

  • psf_model_dir (str) – Directory where the final trained PSF model weights will be saved per cycle.
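
The per-cycle learning rates and epoch counts described above suggest a loop of the following shape. This is a simplified sketch, not the body of train: run_cycles, fit_stage, and fake_fit are hypothetical names, the parametric-then-non-parametric ordering within a cycle is an assumption, and checkpointing, projection, and resetting of non-parametric features are omitted.

```python
def run_cycles(
    total_cycles,
    n_epochs_params,
    n_epochs_non_params,
    learning_rate_params,
    learning_rate_non_params,
    fit_stage,
):
    """Run each cycle's parametric and non-parametric stages in turn."""
    history = []
    for cycle in range(1, total_cycles + 1):
        idx = cycle - 1  # per-cycle lists are indexed from zero
        stages = (
            ("parametric", n_epochs_params[idx], learning_rate_params[idx]),
            (
                "non-parametric",
                n_epochs_non_params[idx],
                learning_rate_non_params[idx],
            ),
        )
        for stage, epochs, learning_rate in stages:
            # fit_stage is a caller-supplied callback standing in for the
            # real model-fitting step.
            result = fit_stage(stage, epochs, learning_rate)
            history.append({"cycle": cycle, "stage": stage, "result": result})
    return history


def fake_fit(stage, epochs, learning_rate):
    """Placeholder fit step that only echoes its settings."""
    return {"epochs": epochs, "learning_rate": learning_rate}


history = run_cycles(2, [20, 20], [100, 120], [1e-2, 1e-2], [1e-1, 1e-1], fake_fit)
for entry in history:
    print(entry["cycle"], entry["stage"], entry["result"])
```

Keeping one list entry per cycle for each setting lets later cycles use different epoch counts or learning rates without changing the loop.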

Notes

  • Utilizes TensorFlow and TensorFlow Addons for model training and optimization.

  • Supports masked mean squared error loss for training with masked data.

  • Allows for projection of data-driven features onto parametric models between cycles.

  • Supports resetting of non-parametric features to initial states.

  • Saves model weights to psf_model_dir after each training cycle (or only after the final cycle if not all cycles are saved).

  • Saves optimizer histories to optimizer_dir.

  • Logs cycle information and time durations.
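
The masked mean squared error mentioned in the notes can be illustrated in plain Python. This is a sketch of the general idea only: the function name masked_mse, the mask convention (1 keeps a value, 0 ignores it), and the normalization by the number of kept values are assumptions, and the loss used by wf_psf operates on TensorFlow tensors and may follow different conventions.

```python
def masked_mse(y_true, y_pred, mask):
    """Mean squared error over the entries where mask is non-zero."""
    total = 0.0
    count = 0
    for target, prediction, keep in zip(y_true, y_pred, mask):
        if keep:
            total += (target - prediction) ** 2
            count += 1
    # Assumption: return 0.0 when every entry is masked out.
    return total / count if count else 0.0


# The third pair has a large error, but the mask excludes it.
value = masked_mse([1.0, 2.0, 3.0], [1.0, 4.0, 10.0], [1, 1, 0])
print(value)
```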