wf_psf.training.train
Train.
A module which defines the classes and methods to manage training of the PSF model.
- Authors:
Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>
Functions
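filepath_chkp_callback(checkpoint_dir, model_name, id_name, current_cycle) | Generate a file path for a checkpoint callback. |
get_gpu_info() | Get GPU Information. |
get_loss_metrics_monitor_and_outputs(training_handler, data_conf) | Generate factory for loss, metrics, monitor, and outputs. |
Set up Training. |
train(training_params, data_conf, checkpoint_dir, optimizer_dir, psf_model_dir) | Train the PSF model over one or more parametric and non-parametric training cycles. |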
Classes
TrainingParamsHandler(training_params) | Training Parameters Handler. |
- class wf_psf.training.train.TrainingParamsHandler(training_params)
Bases: object
Training Parameters Handler.
A class that provides access to the training parameters.
- Parameters:
training_params (Recursive Namespace object) – Recursive Namespace object containing training input parameters
- Attributes:
id_name – ID Name.
learning_rate_non_params – Non-parametric Model Learning Rate.
learning_rate_params – Parametric Model Learning Rate.
model_name – PSF Model Name.
model_params – PSF Model Params.
multi_cycle_params – Training Multi Cycle Parameters.
n_epochs_non_params – Number of Epochs for Non-parametric PSF model.
n_epochs_params – Number of Epochs for Parametric PSF model.
total_cycles – Total Number of Cycles.
training_hparams – Training Hyperparameters.
- property learning_rate_non_params
Non-parametric Model Learning Rate.
Set learning rate for non-parametric PSF model.
- Returns:
learning_rate_non_params – List containing learning rate for non-parametric PSF model
- Return type:
list
- property learning_rate_params
Parametric Model Learning Rate.
Set learning rate for parametric PSF model.
- Returns:
learning_rate_params – List containing learning rate for parametric PSF model
- Return type:
list
- property model_name
PSF Model Name.
Set model_name.
- Returns:
model_name – Name of PSF model
- Return type:
str
- property model_params
PSF Model Params.
Set PSF model training parameters.
- Returns:
model_params – Recursive Namespace object storing PSF model parameters
- Return type:
Recursive Namespace object
- property multi_cycle_params
Training Multi Cycle Parameters.
Set training multi-cycle parameters.
- Returns:
multi_cycle_params – Recursive Namespace object storing training multi-cycle parameters
- Return type:
Recursive Namespace object
- property n_epochs_non_params
Number of Epochs for Non-parametric PSF model.
Set the number of epochs for training non-parametric PSF model.
- Returns:
n_epochs_non_params – List of number of epochs for training non-parametric PSF model.
- Return type:
list
- property n_epochs_params
Number of Epochs for Parametric PSF model.
Set the number of epochs for training parametric PSF model.
- Returns:
n_epochs_params – List of number of epochs for training parametric PSF model.
- Return type:
list
- property total_cycles
Total Number of Cycles.
Set total number of cycles for training.
- Returns:
total_cycles – Total number of cycles for training
- Return type:
int
- property training_hparams
Training Hyperparameters.
Set training hyperparameters.
- Returns:
training_hparams – Recursive Namespace object storing training hyperparameters
- Return type:
Recursive Namespace object
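To illustrate how a property-based handler of this kind can expose nested configuration values, here is a minimal sketch. It is not the wf_psf implementation: the class name `ParamsHandlerSketch`, the use of `SimpleNamespace` in place of the Recursive Namespace object, the nesting of the fields, and all values are assumptions chosen for illustration.

```python
from types import SimpleNamespace

# Hypothetical configuration layout; the real nesting and values may differ.
training_params = SimpleNamespace(
    id_name="_run_01",
    model_params=SimpleNamespace(model_name="poly"),
    training_hparams=SimpleNamespace(
        multi_cycle_params=SimpleNamespace(
            total_cycles=2,
            learning_rate_params=[1e-2, 1e-2],
            learning_rate_non_params=[1e-1, 1e-1],
            n_epochs_params=[20, 20],
            n_epochs_non_params=[100, 120],
        )
    ),
)


class ParamsHandlerSketch:
    """Illustrative stand-in that mirrors a few of the documented properties."""

    def __init__(self, training_params):
        self.training_params = training_params

    @property
    def model_name(self):
        # Assumed to live under the model parameters.
        return self.training_params.model_params.model_name

    @property
    def multi_cycle_params(self):
        # Assumed to live under the training hyperparameters.
        return self.training_params.training_hparams.multi_cycle_params

    @property
    def total_cycles(self):
        return self.multi_cycle_params.total_cycles

    @property
    def n_epochs_non_params(self):
        # One entry per training cycle.
        return self.multi_cycle_params.n_epochs_non_params


handler = ParamsHandlerSketch(training_params)
print(handler.model_name, handler.total_cycles, handler.n_epochs_non_params)
```

Read-only properties keep the lookup path to each value in one place, so calling code does not need to know how the configuration is nested.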
- wf_psf.training.train.filepath_chkp_callback(checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int) -> str
Generate a file path for a checkpoint callback.
- Parameters:
checkpoint_dir (str)
model_name (str)
id_name (str)
current_cycle (int)
- Returns:
A string representing the full file path for the checkpoint callback.
- Return type:
str
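As a rough sketch of how the four arguments can be combined into a per-cycle checkpoint path, consider the helper below. The function name `checkpoint_path_sketch` and the file naming pattern are hypothetical; the pattern actually used by `filepath_chkp_callback` is not stated here and may differ.

```python
import os


def checkpoint_path_sketch(
    checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int
) -> str:
    """Build a checkpoint path that is unique per model, run, and cycle."""
    # Hypothetical naming pattern, chosen only for illustration.
    filename = f"checkpoint_{model_name}{id_name}_cycle{current_cycle}"
    return os.path.join(checkpoint_dir, filename)


path = checkpoint_path_sketch("checkpoints", "poly", "_run_01", 1)
print(path)
```

Including the cycle number in the file name keeps the checkpoint of one cycle from overwriting that of another.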
- wf_psf.training.train.get_gpu_info()
Get GPU Information.
A function to return GPU device name.
- Returns:
device_name – Name of GPU device
- Return type:
str
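One way to query the GPU device name with TensorFlow is sketched below. Whether `get_gpu_info` is implemented this way is an assumption; the helper name `gpu_name_sketch` is hypothetical, and the fallback to an empty string when TensorFlow is not installed is added only so the sketch runs anywhere.

```python
def gpu_name_sketch() -> str:
    """Return the name of the first GPU device, or "" if none is found."""
    try:
        import tensorflow as tf
    except ImportError:
        # Fallback for this sketch only: no TensorFlow, so no GPU name.
        return ""
    # tf.test.gpu_device_name() returns an empty string when no GPU is available.
    return tf.test.gpu_device_name()


device_name = gpu_name_sketch()
print(f"GPU device: {device_name!r}")
```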
- wf_psf.training.train.get_loss_metrics_monitor_and_outputs(training_handler, data_conf)
Generate factory for loss, metrics, monitor, and outputs.
A function to generate loss, metrics, monitor, and outputs for training.
- Parameters:
training_handler (TrainingParamsHandler) – TrainingParamsHandler object containing training parameters
data_conf (object) – Data configuration object containing training and test data
- Returns:
loss (tf.keras.losses.Loss) – Loss function to be used for training
param_metrics (list) – List of metrics for the parametric model
non_param_metrics (list) – List of metrics for the non-parametric model
monitor (str) – Metric to monitor for saving the model
outputs (tf.Tensor) – Tensor containing the outputs for training
output_val (tf.Tensor) – Tensor containing the outputs for validation
- wf_psf.training.train.train(training_params, data_conf, checkpoint_dir, optimizer_dir, psf_model_dir)
Train the PSF model over one or more parametric and non-parametric training cycles.
This function manages multi-cycle training of a parametric + non-parametric PSF model, including initialization, loss/metric configuration, optimizer setup, model checkpointing, and optional projection or resetting of non-parametric features. Each cycle can include both parametric and non-parametric training stages, and training history is saved for each.
- Parameters:
training_params (RecursiveNamespace) – Contains all training configuration parameters, including:
- learning rates per cycle
- number of epochs per component per cycle
- model type and training behavior flags
- multi-cycle definitions and callbacks
data_conf (object) – Contains training and validation datasets via attributes:
- data_conf.training_data: TrainingDataHandler instance with SEDs and positions
- data_conf.test_data: TestDataHandler instance with validation SEDs and positions
checkpoint_dir (str) – Directory where model checkpoints will be saved during training.
optimizer_dir (str) – Directory where the optimizer history (as a NumPy .npy file) will be stored.
psf_model_dir (str) – Directory where the final trained PSF model weights will be saved per cycle.
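The multi-cycle structure described above, with a parametric and a non-parametric stage in each cycle and per-cycle learning rates and epoch counts, can be sketched as a simple schedule. This is an illustration only: `plan_cycles` is a hypothetical helper, the stage order within a cycle is assumed, and the real `train` function additionally handles optimizers, callbacks, and saving.

```python
def plan_cycles(total_cycles, lr_params, lr_non_params,
                n_epochs_params, n_epochs_non_params):
    """List the (cycle, stage, learning_rate, n_epochs) steps of a training run."""
    plan = []
    for cycle in range(1, total_cycles + 1):
        i = cycle - 1  # the per-cycle lists are indexed from zero
        # Assumed order: parametric stage first, then non-parametric stage.
        plan.append((cycle, "parametric", lr_params[i], n_epochs_params[i]))
        plan.append((cycle, "non-parametric", lr_non_params[i], n_epochs_non_params[i]))
    return plan


plan = plan_cycles(
    total_cycles=2,
    lr_params=[1e-2, 1e-2],
    lr_non_params=[1e-1, 1e-1],
    n_epochs_params=[20, 20],
    n_epochs_non_params=[100, 120],
)
for step in plan:
    print(step)
```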
Notes
Utilizes TensorFlow and TensorFlow Addons for model training and optimization.
Supports masked mean squared error loss for training with masked data.
Allows for projection of data-driven features onto parametric models between cycles.
Supports resetting of non-parametric features to initial states.
Saves model weights to psf_model_dir for each training cycle (or only for the final cycle if not all cycles are saved).
Saves optimizer histories to optimizer_dir.
Logs cycle information and time durations.
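For readers unfamiliar with masked losses, the following NumPy sketch shows one common definition of a masked mean squared error. It is not taken from wf_psf: the function name `masked_mse` and the mask convention (1 keeps a pixel, 0 ignores it) are assumptions, and the library's own loss may use a different convention or normalization.

```python
import numpy as np


def masked_mse(y_true, y_pred, mask):
    """Mean squared error computed over the pixels selected by the mask."""
    # Assumption: mask == 1 marks pixels to keep, mask == 0 pixels to ignore.
    mask = np.asarray(mask, dtype=float)
    squared_error = mask * (np.asarray(y_true) - np.asarray(y_pred)) ** 2
    n_valid = mask.sum()
    # Avoid division by zero when every pixel is masked out.
    return float(squared_error.sum() / n_valid) if n_valid > 0 else 0.0


y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.0, 0.0], [3.0, 6.0]])
mask = np.array([[1, 0], [1, 1]])  # the top-right pixel is ignored

loss = masked_mse(y_true, y_pred, mask)
print(loss)
```

Dividing by the number of valid pixels rather than the total pixel count keeps the loss comparable between images with different amounts of masking.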