Source code for wf_psf.training.train

"""Train.

A module which defines the classes and methods
to manage training of the psf model.

:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>

"""

import gc
import numpy as np
import time
import tensorflow as tf
import logging
from wf_psf.psf_models import psf_models
import wf_psf.training.train_utils as train_utils
from wf_psf.utils.optimizer import get_optimizer

logger = logging.getLogger(__name__)


def get_gpu_info():
    """Get GPU Information.

    A function to return the GPU device name.

    Returns
    -------
    device_name: str
        Name of GPU device
    """
    device_name = tf.test.gpu_device_name()
    return device_name

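# Illustrative usage sketch (not from the original module): tf.test.gpu_device_name()
# returns a device string such as "/device:GPU:0" when a GPU is visible and an
# empty string otherwise, so the result can be tested for truthiness:
#
#     >>> name = get_gpu_info()
#     >>> print(name if name else "No GPU found")
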
def setup_training():
    """Set up Training.

    A function to set up training.
    """
    device_name = get_gpu_info()
    logger.info(f"Found GPU at: {device_name}")

def filepath_chkp_callback(
    checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int
) -> str:
    """
    Generate a file path for a checkpoint callback.

    Parameters
    ----------
    checkpoint_dir : str
        The directory where the checkpoint will be saved.
    model_name : str
        The name of the model.
    id_name : str
        The unique identifier for the model instance.
    current_cycle : int
        The current cycle number.

    Returns
    -------
    str
        A string representing the full file path for the checkpoint callback.
    """
    return (
        checkpoint_dir
        + "/checkpoint_callback_"
        + model_name
        + id_name
        + "_cycle"
        + str(current_cycle)
    )

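# Example of the generated checkpoint path (argument values are illustrative,
# not wf_psf defaults):
#
#     >>> filepath_chkp_callback("chkp", "poly", "_run1", 2)
#     'chkp/checkpoint_callback_poly_run1_cycle2'
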
class TrainingParamsHandler:
    """Training Parameters Handler.

    A class to handle access to the training parameters.

    Parameters
    ----------
    training_params: Recursive Namespace object
        Recursive Namespace object containing training input parameters
    """

    def __init__(self, training_params):
        self.training_params = training_params
        self.run_id_name = self.model_name + self.id_name
        self.optimizer_params = {}

    @property
    def id_name(self):
        """ID Name.

        Set unique ID name.

        Returns
        -------
        id_name: str
            A unique ID.
        """
        return self.training_params.id_name

    @property
    def model_name(self):
        """PSF Model Name.

        Set model_name.

        Returns
        -------
        model_name: str
            Name of PSF model
        """
        return self.training_params.model_params.model_name

    @property
    def model_params(self):
        """PSF Model Params.

        Set PSF model training parameters.

        Returns
        -------
        model_params: Recursive Namespace object
            Recursive Namespace object storing PSF model parameters
        """
        return self.training_params.model_params

    @property
    def training_hparams(self):
        """Training Hyperparameters.

        Set training hyperparameters.

        Returns
        -------
        training_hparams: Recursive Namespace object
            Recursive Namespace object storing training hyperparameters
        """
        return self.training_params.training_hparams

    @property
    def multi_cycle_params(self):
        """Training Multi-Cycle Parameters.

        Set training multi-cycle parameters.

        Returns
        -------
        multi_cycle_params: Recursive Namespace object
            Recursive Namespace object storing training multi-cycle parameters
        """
        return self.training_hparams.multi_cycle_params

    @property
    def total_cycles(self):
        """Total Number of Cycles.

        Set total number of cycles for training.

        Returns
        -------
        total_cycles: int
            Total number of cycles for training
        """
        return self.multi_cycle_params.total_cycles

    @property
    def n_epochs_params(self):
        """Number of Epochs for Parametric PSF model.

        Set the number of epochs for training the parametric PSF model.

        Returns
        -------
        n_epochs_params: list
            List of the number of epochs per cycle for training the parametric PSF model.
        """
        return self.multi_cycle_params.n_epochs_params

    @property
    def n_epochs_non_params(self):
        """Number of Epochs for Non-parametric PSF model.

        Set the number of epochs for training the non-parametric PSF model.

        Returns
        -------
        n_epochs_non_params: list
            List of the number of epochs per cycle for training the non-parametric PSF model.
        """
        return self.multi_cycle_params.n_epochs_non_params

    @property
    def learning_rate_params(self):
        """Parametric Model Learning Rate.

        Set learning rate for the parametric PSF model.

        Returns
        -------
        learning_rate_params: list
            List containing the learning rate per cycle for the parametric PSF model
        """
        return self.multi_cycle_params.learning_rate_params

    @property
    def learning_rate_non_params(self):
        """Non-parametric Model Learning Rate.

        Set learning rate for the non-parametric PSF model.

        Returns
        -------
        learning_rate_non_params: list
            List containing the learning rate per cycle for the non-parametric PSF model
        """
        return self.multi_cycle_params.learning_rate_non_params

    def _prepare_callbacks(
        self, checkpoint_dir, current_cycle, monitor="mean_squared_error"
    ):
        """Prepare Callbacks.

        A function to prepare to save the model as a callback.

        Parameters
        ----------
        checkpoint_dir: str
            Checkpoint directory
        current_cycle: int
            Integer representing the current cycle
        monitor: str
            Name of the metric to monitor when saving model weights

        Returns
        -------
        keras.callbacks.ModelCheckpoint
            Class to save the Keras model or model weights at some frequency
        """
        # -----------------------------------------------------
        logger.info("Preparing Keras model callback...")

        return tf.keras.callbacks.ModelCheckpoint(
            filepath_chkp_callback(
                checkpoint_dir, self.model_name, self.id_name, current_cycle
            ),
            monitor=monitor,
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            mode="min",
            save_freq="epoch",
            options=None,
        )

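# Minimal sketch of how the handler exposes nested config values. The namespace
# layout mirrors the training config, but the concrete values below are
# illustrative, not wf_psf defaults:
#
#     >>> from types import SimpleNamespace as NS
#     >>> cfg = NS(
#     ...     id_name="_run1",
#     ...     model_params=NS(model_name="poly"),
#     ...     training_hparams=NS(
#     ...         multi_cycle_params=NS(
#     ...             total_cycles=2,
#     ...             n_epochs_params=[20, 20],
#     ...             n_epochs_non_params=[100, 120],
#     ...             learning_rate_params=[0.01, 0.004],
#     ...             learning_rate_non_params=[0.1, 0.06],
#     ...         )
#     ...     ),
#     ... )
#     >>> handler = TrainingParamsHandler(cfg)
#     >>> handler.run_id_name
#     'poly_run1'
#     >>> handler.learning_rate_params[handler.total_cycles - 1]
#     0.004
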
def get_loss_metrics_monitor_and_outputs(training_handler, data_conf):
    """Factory to return fresh loss, metrics (param & non-param), monitor,
    and outputs for the current cycle.

    Parameters
    ----------
    training_handler: TrainingParamsHandler
        TrainingParamsHandler object containing training parameters
    data_conf: object
        Data configuration object containing training and test data

    Returns
    -------
    loss: tf.keras.losses.Loss
        Loss function to be used for training
    param_metrics: list
        List of metrics for the parametric model
    non_param_metrics: list
        List of metrics for the non-parametric model
    monitor: str
        Metric to monitor for saving the model
    outputs: tf.Tensor
        Tensor containing the outputs for training
    output_val: tf.Tensor
        Tensor containing the outputs for validation
    """
    if training_handler.training_hparams.loss == "mask_mse":
        loss = train_utils.MaskedMeanSquaredError()
        monitor = "loss"
        param_metrics = [train_utils.MaskedMeanSquaredErrorMetric()]
        non_param_metrics = [train_utils.MaskedMeanSquaredErrorMetric()]
        outputs = tf.stack(
            [
                data_conf.training_data.dataset["noisy_stars"],
                data_conf.training_data.dataset["masks"],
            ],
            axis=-1,
        )
        output_val = tf.stack(
            [
                data_conf.test_data.dataset["stars"],
                data_conf.test_data.dataset["masks"],
            ],
            axis=-1,
        )
    else:
        loss = tf.keras.losses.MeanSquaredError()
        monitor = "mean_squared_error"
        param_metrics = [tf.keras.metrics.MeanSquaredError()]
        non_param_metrics = [tf.keras.metrics.MeanSquaredError()]
        outputs = data_conf.training_data.dataset["noisy_stars"]
        output_val = data_conf.test_data.dataset["stars"]

    return loss, param_metrics, non_param_metrics, monitor, outputs, output_val

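# Shape sketch for the "mask_mse" branch: stars and masks of shape
# (n_stars, h, w) are stacked along a new trailing axis, giving
# (n_stars, h, w, 2), so the masked loss can split targets and masks back
# out of the last axis (the array sizes here are illustrative):
#
#     >>> stars = tf.zeros((10, 32, 32))
#     >>> masks = tf.ones((10, 32, 32))
#     >>> tf.stack([stars, masks], axis=-1).shape
#     TensorShape([10, 32, 32, 2])
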
def train(
    training_params,
    data_conf,
    checkpoint_dir,
    optimizer_dir,
    psf_model_dir,
):
    """
    Train the PSF model over one or more parametric and non-parametric training cycles.

    This function manages multi-cycle training of a parametric + non-parametric
    PSF model, including initialization, loss/metric configuration, optimizer
    setup, model checkpointing, and optional projection or resetting of
    non-parametric features. Each cycle can include both parametric and
    non-parametric training stages, and training history is saved for each.

    Parameters
    ----------
    training_params : RecursiveNamespace
        Contains all training configuration parameters, including:
        - learning rates per cycle
        - number of epochs per component per cycle
        - model type and training behavior flags
        - multi-cycle definitions and callbacks
    data_conf : object
        Contains training and validation datasets via attributes:
        - data_conf.training_data: TrainingDataHandler instance with SEDs and positions
        - data_conf.test_data: TestDataHandler instance with validation SEDs and positions
    checkpoint_dir : str
        Directory where model checkpoints will be saved during training.
    optimizer_dir : str
        Directory where the optimizer history (as a NumPy .npy file) will be stored.
    psf_model_dir : str
        Directory where the final trained PSF model weights will be saved per cycle.

    Returns
    -------
    None

    Side Effects
    ------------
    - Saves model weights to `psf_model_dir` per training cycle (or only the
      final one if not all cycles are saved)
    - Saves optimizer histories to `optimizer_dir`
    - Logs cycle information and time durations
    """
    # Start measuring elapsed time
    starting_time = time.time()

    training_handler = TrainingParamsHandler(training_params)

    psf_model = psf_models.get_psf_model(
        training_handler.model_params,
        training_handler.training_hparams,
        data_conf,
    )

    logger.info(f"PSF Model class: `{training_handler.model_name}` initialized...")

    # Model Training
    # -----------------------------------------------------
    # Save optimisation history in the saving dict
    saving_optim_hist = {}

    # Perform all the necessary cycles
    current_cycle = 0

    while training_handler.total_cycles > current_cycle:
        current_cycle += 1

        # Instantiate fresh loss, monitor, and independent metric objects
        # per training phase (param / non-param)
        loss, param_metrics, non_param_metrics, monitor, outputs, output_val = (
            get_loss_metrics_monitor_and_outputs(training_handler, data_conf)
        )

        # If projected learning is enabled, project DD_features.
        if hasattr(psf_model, "project_dd_features") and psf_model.project_dd_features:
            if current_cycle > 1:
                psf_model.project_DD_features(
                    psf_model.zernike_maps
                )  # make this a callable function
                logger.info(
                    "Projected non-parametric DD features onto the parametric model."
                )

        if hasattr(psf_model, "reset_dd_features") and psf_model.reset_dd_features:
            psf_model.tf_np_poly_opd.init_vars()
            logger.info("DataDriven features were reset to random initialisation.")

        # Prepare to save the model as a callback
        # -----------------------------------------------------
        model_chkp_callback = training_handler._prepare_callbacks(
            checkpoint_dir, current_cycle, monitor=monitor
        )

        # Prepare the optimizers
        param_optim = get_optimizer(
            optimizer_config=training_handler.training_hparams.optimizer,
            learning_rate=training_handler.learning_rate_params[current_cycle - 1],
        )
        non_param_optim = get_optimizer(
            optimizer_config=training_handler.training_hparams.optimizer,
            learning_rate=training_handler.learning_rate_non_params[current_cycle - 1],
        )

        logger.info(f"Starting cycle {current_cycle}..")
        start_cycle = time.time()

        # Compute training per cycle
        (
            psf_model,
            hist_param,
            hist_non_param,
        ) = train_utils.general_train_cycle(
            psf_model,
            inputs=[
                data_conf.training_data.dataset["positions"],
                data_conf.training_data.sed_data,
            ],
            outputs=outputs,
            validation_data=(
                [
                    data_conf.test_data.dataset["positions"],
                    data_conf.test_data.sed_data,
                ],
                output_val,
            ),
            batch_size=training_handler.training_hparams.batch_size,
            learning_rate_param=training_handler.learning_rate_params[
                current_cycle - 1
            ],
            learning_rate_non_param=training_handler.learning_rate_non_params[
                current_cycle - 1
            ],
            n_epochs_param=training_handler.n_epochs_params[current_cycle - 1],
            n_epochs_non_param=training_handler.n_epochs_non_params[current_cycle - 1],
            param_optim=param_optim,
            non_param_optim=non_param_optim,
            param_loss=loss,
            non_param_loss=loss,
            param_metrics=param_metrics,
            non_param_metrics=non_param_metrics,
            param_callback=None,
            non_param_callback=None,
            general_callback=[model_chkp_callback],
            first_run=current_cycle == 1,
            cycle_def=training_handler.multi_cycle_params.cycle_def,
            use_sample_weights=training_handler.model_params.use_sample_weights,
            apply_sigmoid=training_handler.model_params.sample_weights_sigmoid.apply_sigmoid,
            sigmoid_max_val=training_handler.model_params.sample_weights_sigmoid.sigmoid_max_val,
            sigmoid_power_k=training_handler.model_params.sample_weights_sigmoid.sigmoid_power_k,
            verbose=2,
        )

        # Save the weights at the end of the nth cycle
        if training_handler.multi_cycle_params.save_all_cycles:
            psf_model.save_weights(
                psf_model_dir
                + "/psf_model_"
                + training_handler.model_name
                + training_handler.id_name
                + "_cycle"
                + str(current_cycle)
            )

        end_cycle = time.time()
        logger.info(f"Cycle {current_cycle} elapsed time: {end_cycle - start_cycle}")

        # Save optimisation history in the saving dict
        if (
            hasattr(psf_model, "save_optim_history_param")
            and psf_model.save_optim_history_param
        ):
            saving_optim_hist[f"param_cycle{current_cycle}"] = hist_param.history
        if (
            hasattr(psf_model, "save_optim_history_nonparam")
            and psf_model.save_optim_history_nonparam
        ):
            saving_optim_hist[f"nonparam_cycle{current_cycle}"] = hist_non_param.history

    # Save the last cycle if no cycles were saved
    if not training_handler.multi_cycle_params.save_all_cycles:
        psf_model.save_weights(
            psf_model_dir
            + "/psf_model_"
            + training_handler.model_name
            + training_handler.id_name
            + "_cycle"
            + str(current_cycle)
        )

    # Save optimisation history dictionary
    np.save(
        optimizer_dir
        + "/optim_hist_"
        + training_handler.model_name
        + training_handler.id_name
        + ".npy",
        saving_optim_hist,
    )

    # Print final time
    final_time = time.time()
    logger.info("\nTotal elapsed time: %f" % (final_time - starting_time))

    logger.info("\nTraining complete..")

    # Clean up memory
    del psf_model
    gc.collect()
    tf.keras.backend.clear_session()