wf_psf.training.train
Train.
A module that defines the classes and methods used to manage training of the PSF model.
- Authors:
Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>
Functions
filepath_chkp_callback | Generate a file path for a checkpoint callback.
get_gpu_info | Get GPU Information.
get_loss_metrics_monitor_and_outputs | Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle.
Set up Training.
train | Train the PSF model over one or more parametric and non-parametric training cycles.
Classes
TrainingParamsHandler | Training Parameters Handler.
- class wf_psf.training.train.TrainingParamsHandler(training_params)
Bases: object
Training Parameters Handler.
A class to handle access to the training parameters.
- Parameters:
training_params (Recursive Namespace object) – Recursive Namespace object containing training input parameters
- Attributes:
id_name – ID Name.
learning_rate_non_params – Non-parametric Model Learning Rate.
learning_rate_params – Parametric Model Learning Rate.
model_name – PSF Model Name.
model_params – PSF Model Params.
multi_cycle_params – Training Multi Cycle Parameters.
n_epochs_non_params – Number of Epochs for Non-parametric PSF model.
n_epochs_params – Number of Epochs for Parametric PSF model.
total_cycles – Total Number of Cycles.
training_hparams – Training Hyperparameters.
- property learning_rate_non_params
Non-parametric Model Learning Rate.
Set learning rate for non-parametric PSF model.
- Returns:
learning_rate_non_params – List containing learning rate for non-parametric PSF model
- Return type:
list
- property learning_rate_params
Parametric Model Learning Rate.
Set learning rate for parametric PSF model.
- Returns:
learning_rate_params – List containing learning rate for parametric PSF model
- Return type:
list
- property model_name
PSF Model Name.
Set model_name.
- Returns:
model_name – Name of PSF model
- Return type:
- property model_params
PSF Model Params.
Set PSF model training parameters
- Returns:
model_params – Recursive Namespace object storing PSF model parameters
- Return type:
Recursive Namespace object
- property multi_cycle_params
Training Multi Cycle Parameters.
Set training multi cycle parameters
- Returns:
multi_cycle_params – Recursive Namespace object storing training multi-cycle parameters
- Return type:
Recursive Namespace object
- property n_epochs_non_params
Number of Epochs for Non-parametric PSF model.
Set the number of epochs for training non-parametric PSF model.
- Returns:
n_epochs_non_params – List of number of epochs for training non-parametric PSF model.
- Return type:
list
- property n_epochs_params
Number of Epochs for Parametric PSF model.
Set the number of epochs for training parametric PSF model.
- Returns:
n_epochs_params – List of number of epochs for training parametric PSF model.
- Return type:
list
- property total_cycles
Total Number of Cycles.
Set total number of cycles for training.
- Returns:
total_cycles – Total number of cycles for training
- Return type:
- property training_hparams
Training Hyperparameters.
Set training hyperparameters.
- Returns:
training_hparams – Recursive Namespace object storing training hyperparameters
- Return type:
Recursive Namespace object
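The property-based access described for this class can be sketched as follows. This is a minimal illustration, not the wf_psf implementation: `ParamsHandlerSketch` is a hypothetical name, `types.SimpleNamespace` stands in for the Recursive Namespace object, and the nesting of the fields (for example, that the per-cycle lists live under `training_hparams.multi_cycle_params`) is an assumption made for the example.

```python
from types import SimpleNamespace


class ParamsHandlerSketch:
    """Hypothetical read-only view over a nested parameter namespace."""

    def __init__(self, training_params):
        self.training_params = training_params

    @property
    def id_name(self):
        return self.training_params.id_name

    @property
    def training_hparams(self):
        return self.training_params.training_hparams

    @property
    def multi_cycle_params(self):
        # Assumed location of the per-cycle settings.
        return self.training_hparams.multi_cycle_params

    @property
    def total_cycles(self):
        return self.multi_cycle_params.total_cycles

    @property
    def learning_rate_params(self):
        return self.multi_cycle_params.learning_rate_params

    @property
    def n_epochs_non_params(self):
        return self.multi_cycle_params.n_epochs_non_params


# Placeholder values, chosen only to exercise the properties.
demo_params = SimpleNamespace(
    id_name="_demo",
    training_hparams=SimpleNamespace(
        multi_cycle_params=SimpleNamespace(
            total_cycles=2,
            learning_rate_params=[0.01, 0.004],
            n_epochs_non_params=[100, 120],
        )
    ),
)

handler = ParamsHandlerSketch(demo_params)
print(handler.total_cycles, handler.n_epochs_non_params)
```

Exposing each value as a property keeps the lookup path in one place, so callers do not need to know how the configuration is nested.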
- wf_psf.training.train.filepath_chkp_callback(checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int) -> str
Generate a file path for a checkpoint callback.
- Parameters:
checkpoint_dir (str)
model_name (str)
id_name (str)
current_cycle (int)
- Returns:
A string representing the full file path for the checkpoint callback.
- Return type:
str
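As an illustration of how the four arguments could be combined into one path, here is a minimal sketch. The function name `checkpoint_path_sketch` and the file-name pattern are assumptions made for the example; the source does not state the pattern wf_psf actually uses.

```python
import os


def checkpoint_path_sketch(
    checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int
) -> str:
    """Build a per-cycle checkpoint path (assumed naming pattern)."""
    # Hypothetical pattern: model name, run identifier, then the cycle number.
    filename = f"checkpoint_callback_{model_name}{id_name}_cycle{current_cycle}"
    return os.path.join(checkpoint_dir, filename)


# Placeholder arguments for demonstration only.
path = checkpoint_path_sketch("checkpoints", "demo_model", "_run1", 2)
print(path)
```

Including the cycle number in the file name keeps checkpoints from different training cycles from overwriting each other.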
- wf_psf.training.train.get_gpu_info()
Get GPU Information.
A function to return GPU device name.
- Returns:
device_name – Name of GPU device
- Return type:
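The source does not say how the device name is obtained. One possible approach, sketched here under the assumption that TensorFlow is the backend, is to query `tf.config.list_physical_devices`. The function name `gpu_device_name_sketch` is hypothetical, and the fallback to an empty string when TensorFlow or a GPU is unavailable is a choice made for this example.

```python
def gpu_device_name_sketch() -> str:
    """Return the name of the first visible GPU, or "" if none is found."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed, so no GPU can be reported.
        return ""
    gpus = tf.config.list_physical_devices("GPU")
    return gpus[0].name if gpus else ""


device_name = gpu_device_name_sketch()
print(repr(device_name))
```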
- wf_psf.training.train.get_loss_metrics_monitor_and_outputs(training_handler, data_conf)
Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle.
- Parameters:
training_handler (TrainingParamsHandler) – TrainingParamsHandler object containing training parameters
data_conf (object) – Data configuration object containing training and test data
- Returns:
loss (tf.keras.losses.Loss) – Loss function to be used for training
param_metrics (list) – List of metrics for the parametric model
non_param_metrics (list) – List of metrics for the non-parametric model
monitor (str) – Metric to monitor for saving the model
outputs (tf.Tensor) – Tensor containing the outputs for training
output_val (tf.Tensor) – Tensor containing the outputs for validation
- wf_psf.training.train.train(training_params, data_conf, checkpoint_dir, optimizer_dir, psf_model_dir)
Train the PSF model over one or more parametric and non-parametric training cycles.
This function manages multi-cycle training of a parametric + non-parametric PSF model, including initialization, loss/metric configuration, optimizer setup, model checkpointing, and optional projection or resetting of non-parametric features. Each cycle can include both parametric and non-parametric training stages, and training history is saved for each.
- Parameters:
training_params (RecursiveNamespace) – Contains all training configuration parameters, including:
- learning rates per cycle
- number of epochs per component per cycle
- model type and training behavior flags
- multi-cycle definitions and callbacks
data_conf (object) – Contains training and validation datasets via attributes:
- data_conf.training_data: TrainingDataHandler instance with SEDs and positions
- data_conf.test_data: TestDataHandler instance with validation SEDs and positions
checkpoint_dir (str) – Directory where model checkpoints will be saved during training.
optimizer_dir (str) – Directory where the optimizer history (as a NumPy .npy file) will be stored.
psf_model_dir (str) – Directory where the final trained PSF model weights will be saved per cycle.
- Returns:
None
Side Effects
- Saves model weights to psf_model_dir per training cycle (or only the final one if not all are saved)
- Saves optimizer histories to optimizer_dir
- Logs cycle information and time durations
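The multi-cycle structure described for train can be sketched in plain Python. This is an illustration of the control flow only, not the wf_psf code: `run_cycles_sketch` and `record_stage` are hypothetical names, the parametric-then-non-parametric order within a cycle is an assumption, and `record_stage` stands in for the real model fitting by simply recording the settings it receives.

```python
def run_cycles_sketch(
    total_cycles,
    lr_params,
    lr_non_params,
    n_epochs_params,
    n_epochs_non_params,
    train_stage,
):
    """Run both stages for each cycle and collect what each stage returns."""
    history = {}
    for cycle in range(1, total_cycles + 1):
        idx = cycle - 1  # the per-cycle lists are indexed from zero
        # Assumed order: parametric stage first, then non-parametric.
        param_result = train_stage(
            "parametric", lr_params[idx], n_epochs_params[idx]
        )
        non_param_result = train_stage(
            "non-parametric", lr_non_params[idx], n_epochs_non_params[idx]
        )
        history[f"cycle{cycle}"] = {
            "param": param_result,
            "non_param": non_param_result,
        }
    return history


def record_stage(stage, learning_rate, n_epochs):
    # Stand-in for a real fit call: only records the requested settings.
    return {"stage": stage, "learning_rate": learning_rate, "epochs": n_epochs}


# Placeholder settings for two cycles.
history = run_cycles_sketch(
    total_cycles=2,
    lr_params=[0.01, 0.004],
    lr_non_params=[0.1, 0.06],
    n_epochs_params=[20, 20],
    n_epochs_non_params=[100, 120],
    train_stage=record_stage,
)
print(history["cycle2"]["non_param"])
```

Each per-cycle list needs at least `total_cycles` entries; a shorter list raises an IndexError in this sketch.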