wf_psf.training.train module
Train.
A module which defines the classes and methods to manage training of the PSF model.
- Author:
Jennifer Pollack <jennifer.pollack@cea.fr>
- class wf_psf.training.train.TrainingParamsHandler(training_params)
Bases: object
Training Parameters Handler.
A class to handle access to the training parameters.
- Parameters:
training_params (Recursive Namespace object) – Recursive Namespace object containing training input parameters
- Attributes:
id_name
ID Name.
learning_rate_non_params
Non-parametric Model Learning Rate.
learning_rate_params
Parametric Model Learning Rate.
model_name
PSF Model Name.
model_params
PSF Model Params.
multi_cycle_params
Training Multi Cycle Parameters.
n_epochs_non_params
Number of Epochs for Non-parametric PSF model.
n_epochs_params
Number of Epochs for Parametric PSF model.
total_cycles
Total Number of Cycles.
training_hparams
Training Hyperparameters.
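The handler is constructed from a Recursive Namespace object, i.e. nested configuration values reachable by attribute access. The sketch below shows one way such an object can be built from a plain nested dict, using the standard library's types.SimpleNamespace as a stand-in. The helper to_namespace and all key names in the sample configuration are hypothetical illustrations, not the wf_psf configuration schema or its RecursiveNamespace implementation.

```python
from types import SimpleNamespace
from typing import Any


def to_namespace(obj: Any) -> Any:
    """Recursively convert nested dicts (and lists of dicts) into SimpleNamespace objects.

    Hypothetical helper: a stand-in for a "Recursive Namespace object".
    """
    if isinstance(obj, dict):
        return SimpleNamespace(**{key: to_namespace(value) for key, value in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(item) for item in obj]
    # Scalars (str, int, float, ...) are returned unchanged.
    return obj


# Hypothetical training configuration; key names are illustrative only.
config = {
    "id_name": "_sample_run",
    "model_params": {"model_name": "poly"},
    "training_hparams": {
        "multi_cycle_params": {
            "total_cycles": 2,
            "learning_rate_params": [0.01, 0.004],
            "learning_rate_non_params": [0.1, 0.06],
            "n_epochs_params": [20, 20],
            "n_epochs_non_params": [100, 120],
        }
    },
}

training_params = to_namespace(config)
print(training_params.training_hparams.multi_cycle_params.total_cycles)  # 2
```

Attribute access (training_params.model_params.model_name) then replaces nested dictionary lookups.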
- property learning_rate_non_params
Non-parametric Model Learning Rate.
Return the learning rate for the non-parametric PSF model.
- Returns:
learning_rate_non_params – List containing the learning rate for the non-parametric PSF model
- Return type:
list
- property learning_rate_params
Parametric Model Learning Rate.
Return the learning rate for the parametric PSF model.
- Returns:
learning_rate_params – List containing the learning rate for the parametric PSF model
- Return type:
list
- property model_name
PSF Model Name.
Return the name of the PSF model.
- Returns:
model_name – Name of the PSF model
- Return type:
str
- property model_params
PSF Model Params.
Return the PSF model parameters.
- Returns:
model_params – Recursive Namespace object storing PSF model parameters
- Return type:
Recursive Namespace object
- property multi_cycle_params
Training Multi Cycle Parameters.
Return the training multi-cycle parameters.
- Returns:
multi_cycle_params – Recursive Namespace object storing training multi-cycle parameters
- Return type:
Recursive Namespace object
- property n_epochs_non_params
Number of Epochs for Non-parametric PSF model.
Return the number of epochs for training the non-parametric PSF model.
- Returns:
n_epochs_non_params – List of the number of epochs for training the non-parametric PSF model.
- Return type:
list
- property n_epochs_params
Number of Epochs for Parametric PSF model.
Return the number of epochs for training the parametric PSF model.
- Returns:
n_epochs_params – List of the number of epochs for training the parametric PSF model.
- Return type:
list
- property total_cycles
Total Number of Cycles.
Return the total number of training cycles.
- Returns:
total_cycles – Total number of cycles for training
- Return type:
int
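The learning-rate and epoch properties are documented as lists, alongside a total cycle count. Assuming (this page does not state it) that entry i of each list applies to training cycle i + 1, the sketch below pairs every cycle with its settings and fails early if any list is shorter than total_cycles. The function cycle_schedule is hypothetical and not part of wf_psf.

```python
from typing import List, Tuple


def cycle_schedule(
    total_cycles: int,
    learning_rate_params: List[float],
    learning_rate_non_params: List[float],
    n_epochs_params: List[int],
    n_epochs_non_params: List[int],
) -> List[Tuple[int, float, float, int, int]]:
    """Pair each 1-based cycle with (lr_params, lr_non_params, epochs_params, epochs_non_params).

    Assumption: list index i holds the setting for cycle i + 1.
    """
    per_cycle_lists = {
        "learning_rate_params": learning_rate_params,
        "learning_rate_non_params": learning_rate_non_params,
        "n_epochs_params": n_epochs_params,
        "n_epochs_non_params": n_epochs_non_params,
    }
    # Reject configurations that do not provide a value for every cycle.
    for name, values in per_cycle_lists.items():
        if len(values) < total_cycles:
            raise ValueError(
                f"{name} has {len(values)} entries but total_cycles={total_cycles}"
            )
    # zip stops at the shortest input, which is range(...) after the check above.
    return list(
        zip(
            range(1, total_cycles + 1),
            learning_rate_params,
            learning_rate_non_params,
            n_epochs_params,
            n_epochs_non_params,
        )
    )


print(cycle_schedule(2, [0.01, 0.004], [0.1, 0.06], [20, 20], [100, 120]))
# [(1, 0.01, 0.1, 20, 100), (2, 0.004, 0.06, 20, 120)]
```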
- property training_hparams
Training Hyperparameters.
Return the training hyperparameters.
- Returns:
training_hparams – Recursive Namespace object storing training hyperparameters
- Return type:
Recursive Namespace object
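The class exposes each training setting as a read-only property over the nested parameter object. A minimal sketch of that pattern follows, assuming the attribute layout shown in the sample namespace (model_name under model_params, cycle settings under training_hparams.multi_cycle_params). The class name TrainingParamsHandlerSketch and that layout are assumptions for illustration, not the wf_psf source.

```python
from types import SimpleNamespace


class TrainingParamsHandlerSketch:
    """Read-only accessors over a nested namespace of training parameters (hypothetical sketch)."""

    def __init__(self, training_params):
        self.training_params = training_params

    @property
    def id_name(self):
        return self.training_params.id_name

    @property
    def model_params(self):
        return self.training_params.model_params

    @property
    def model_name(self):
        return self.model_params.model_name

    @property
    def training_hparams(self):
        return self.training_params.training_hparams

    @property
    def multi_cycle_params(self):
        return self.training_hparams.multi_cycle_params

    @property
    def total_cycles(self):
        return self.multi_cycle_params.total_cycles

    @property
    def learning_rate_params(self):
        return self.multi_cycle_params.learning_rate_params

    @property
    def learning_rate_non_params(self):
        return self.multi_cycle_params.learning_rate_non_params

    @property
    def n_epochs_params(self):
        return self.multi_cycle_params.n_epochs_params

    @property
    def n_epochs_non_params(self):
        return self.multi_cycle_params.n_epochs_non_params


# Hypothetical nested parameters; values are illustrative only.
params = SimpleNamespace(
    id_name="_sample_run",
    model_params=SimpleNamespace(model_name="poly"),
    training_hparams=SimpleNamespace(
        multi_cycle_params=SimpleNamespace(
            total_cycles=2,
            learning_rate_params=[0.01, 0.004],
            learning_rate_non_params=[0.1, 0.06],
            n_epochs_params=[20, 20],
            n_epochs_non_params=[100, 120],
        )
    ),
)

handler = TrainingParamsHandlerSketch(params)
print(handler.model_name, handler.total_cycles)  # poly 2
```

Because the properties define no setters, assigning to handler.total_cycles raises AttributeError, which keeps the configuration read-only through the handler.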
- wf_psf.training.train.filepath_chkp_callback(checkpoint_dir, model_name, id_name, current_cycle)
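filepath_chkp_callback has no docstring on this page. Judging only by its name and arguments, it presumably builds a per-cycle checkpoint file path; the sketch below shows that idea with os.path.join. The function name filepath_chkp_callback_sketch and the filename pattern are hypothetical and may differ from what wf_psf actually produces.

```python
import os


def filepath_chkp_callback_sketch(
    checkpoint_dir: str, model_name: str, id_name: str, current_cycle: int
) -> str:
    """Build a per-cycle checkpoint path (hypothetical naming pattern)."""
    filename = f"checkpoint_callback_{model_name}{id_name}_cycle{current_cycle}"
    # os.path.join inserts the platform's path separator.
    return os.path.join(checkpoint_dir, filename)


print(filepath_chkp_callback_sketch("/tmp/checkpoints", "poly", "_sample_run", 1))
# On POSIX: /tmp/checkpoints/checkpoint_callback_poly_sample_run_cycle1
```

Encoding the cycle number in the filename keeps checkpoints from successive cycles from overwriting each other.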
- wf_psf.training.train.get_gpu_info()
Get GPU Information.
A function to return the GPU device name.
- Returns:
device_name – Name of the GPU device
- Return type:
str
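A minimal sketch of querying the GPU device name, assuming TensorFlow is the backend. tf.test.gpu_device_name() returns a string such as '/device:GPU:0', or an empty string when no GPU is visible. The function name get_gpu_info_sketch and the fallback for a missing TensorFlow install are additions for illustration, not the wf_psf implementation.

```python
def get_gpu_info_sketch() -> str:
    """Return the name of the GPU device visible to TensorFlow, or '' if there is none."""
    try:
        import tensorflow as tf
    except ImportError:
        # Assumption for this sketch: report "no GPU" instead of failing when TensorFlow is absent.
        return ""
    # Returns e.g. '/device:GPU:0', or '' when no GPU is available.
    return tf.test.gpu_device_name()


device_name = get_gpu_info_sketch()
print(device_name or "No GPU found")
```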
- wf_psf.training.train.train(training_params, training_data, test_data, checkpoint_dir, optimizer_dir, psf_model_dir)
Train.
A function to train the PSF model.
- Parameters:
training_params (Recursive Namespace object) – Recursive Namespace object containing the training parameters
training_data (object) – TrainingDataHandler object containing the training data parameters
test_data (object) – TestDataHandler object containing the test data parameters
checkpoint_dir (str) – Absolute path to checkpoint directory
optimizer_dir (str) – Absolute path to optimizer history directory
psf_model_dir (str) – Absolute path to PSF model directory
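The per-cycle learning rates and epoch counts for the parametric and non-parametric parts suggest a multi-cycle loop. The skeleton below shows one such loop under stated assumptions: each cycle trains the parametric part and then the non-parametric part, and writes to a per-cycle checkpoint path. The function train_cycles_sketch, the caller-supplied fit_params and fit_non_params callables, and the ordering within a cycle are all hypothetical. The sketch does no model fitting and only shows the control flow.

```python
import os
from typing import Callable, List

# Hypothetical signature of a fit step: (learning_rate, n_epochs, checkpoint_path) -> None.
FitFn = Callable[[float, int, str], None]


def train_cycles_sketch(
    total_cycles: int,
    learning_rate_params: List[float],
    learning_rate_non_params: List[float],
    n_epochs_params: List[int],
    n_epochs_non_params: List[int],
    fit_params: FitFn,
    fit_non_params: FitFn,
    checkpoint_dir: str,
    model_name: str,
    id_name: str,
) -> List[str]:
    """Run each cycle's parametric, then non-parametric, fit; return the checkpoint paths."""
    per_cycle = [learning_rate_params, learning_rate_non_params, n_epochs_params, n_epochs_non_params]
    if min(len(values) for values in per_cycle) < total_cycles:
        raise ValueError("every per-cycle list needs at least total_cycles entries")

    checkpoints = []
    for cycle in range(1, total_cycles + 1):
        i = cycle - 1  # assumption: list index i holds the setting for cycle i + 1
        path = os.path.join(checkpoint_dir, f"checkpoint_{model_name}{id_name}_cycle{cycle}")
        fit_params(learning_rate_params[i], n_epochs_params[i], path)
        fit_non_params(learning_rate_non_params[i], n_epochs_non_params[i], path)
        checkpoints.append(path)
    return checkpoints


# Usage with stub fit steps that only record what they were asked to do.
calls = []
paths = train_cycles_sketch(
    total_cycles=2,
    learning_rate_params=[0.01, 0.004],
    learning_rate_non_params=[0.1, 0.06],
    n_epochs_params=[20, 20],
    n_epochs_non_params=[100, 120],
    fit_params=lambda lr, epochs, path: calls.append(("params", lr, epochs)),
    fit_non_params=lambda lr, epochs, path: calls.append(("non_params", lr, epochs)),
    checkpoint_dir="/tmp/checkpoints",
    model_name="poly",
    id_name="_sample_run",
)
print(calls)
```

Passing the fit steps in as callables keeps the loop independent of any particular model or optimizer, so the skeleton can be tested without a deep-learning framework.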