wf_psf.training.train_utils module
- class wf_psf.training.train_utils.L1ParamScheduler(l1_schedule_rule)
Bases: Callback
L1 rate scheduler which sets the L1 rate according to a schedule.
- Parameters:
l1_schedule_rule (function) – A function that takes an epoch index (integer, indexed from 0) and the current l1_rate as inputs and returns a new l1_rate (float) as output.
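A schedule rule only has to honour the contract above: take the 0-based epoch index and the current l1_rate, and return the new rate as a float. The rule below (halve the rate every 10 epochs) and its name, step_decay_l1_rule, are hypothetical illustrations, not part of wf_psf.

```python
def step_decay_l1_rule(epoch: int, l1_rate: float) -> float:
    """Hypothetical schedule rule: halve the L1 rate every 10 epochs.

    Follows the documented contract: receives a 0-based epoch index and
    the current l1_rate, and returns the new l1_rate as a float.
    """
    if epoch > 0 and epoch % 10 == 0:
        return l1_rate * 0.5
    return float(l1_rate)


# Hypothetical usage with the documented constructor, e.g. inside a Keras
# callbacks list:
#     scheduler = L1ParamScheduler(l1_schedule_rule=step_decay_l1_rule)

# Simulate 25 epochs: the rate is halved at epochs 10 and 20.
rate = 1e-3
for epoch in range(25):
    rate = step_decay_l1_rule(epoch, rate)
print(rate)
```

Keeping the rule a pure function of (epoch, l1_rate) makes it easy to test in isolation before attaching it to a training run.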
Methods
on_batch_begin(batch[, logs]) – A backwards compatibility alias for on_train_batch_begin.
on_batch_end(batch[, logs]) – A backwards compatibility alias for on_train_batch_end.
on_epoch_begin(epoch[, logs]) – Called at the start of an epoch.
on_epoch_end(epoch[, logs]) – Called at the end of an epoch.
on_predict_batch_begin(batch[, logs]) – Called at the beginning of a batch in predict methods.
on_predict_batch_end(batch[, logs]) – Called at the end of a batch in predict methods.
on_predict_begin([logs]) – Called at the beginning of prediction.
on_predict_end([logs]) – Called at the end of prediction.
on_test_batch_begin(batch[, logs]) – Called at the beginning of a batch in evaluate methods.
on_test_batch_end(batch[, logs]) – Called at the end of a batch in evaluate methods.
on_test_begin([logs]) – Called at the beginning of evaluation or validation.
on_test_end([logs]) – Called at the end of evaluation or validation.
on_train_batch_begin(batch[, logs]) – Called at the beginning of a training batch in fit methods.
on_train_batch_end(batch[, logs]) – Called at the end of a training batch in fit methods.
on_train_begin([logs]) – Called at the beginning of training.
on_train_end([logs]) – Called at the end of training.
set_model
set_params
- on_epoch_begin(epoch, logs=None)
Called at the start of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.
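Following the usual Keras pattern (as in tf.keras.callbacks.LearningRateScheduler), a scheduler of this kind reads the current rate at the start of each epoch, passes it with the epoch index to the schedule rule, and stores the result. The sketch below imitates that flow without TensorFlow. FakeModel, its l1_rate attribute, its set_l1_rate method and L1SchedulerSketch are hypothetical stand-ins; consult the wf_psf source for how the real model exposes its L1 rate.

```python
from typing import Callable, Optional


class FakeModel:
    """Hypothetical stand-in for a model that exposes an L1 rate."""

    def __init__(self, l1_rate: float) -> None:
        self.l1_rate = l1_rate

    def set_l1_rate(self, new_rate: float) -> None:
        self.l1_rate = new_rate


class L1SchedulerSketch:
    """Mimics the Keras callback flow: set_model() once, then on_epoch_begin() per epoch."""

    def __init__(self, l1_schedule_rule: Callable[[int, float], float]) -> None:
        self.l1_schedule_rule = l1_schedule_rule
        self.model: Optional[FakeModel] = None

    def set_model(self, model: FakeModel) -> None:
        self.model = model

    def on_epoch_begin(self, epoch: int, logs=None) -> None:
        # Read the current rate, ask the rule for the next one, write it back.
        current = float(self.model.l1_rate)
        self.model.set_l1_rate(float(self.l1_schedule_rule(epoch, current)))


model = FakeModel(l1_rate=1.0)
# Hypothetical rule: keep the rate at epoch 0, then decay it by 10% per epoch.
scheduler = L1SchedulerSketch(lambda epoch, rate: rate * 0.9 if epoch > 0 else rate)
scheduler.set_model(model)  # Keras performs this step itself during fit()
for epoch in range(3):
    scheduler.on_epoch_begin(epoch)
print(model.l1_rate)
```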
- wf_psf.training.train_utils.general_train_cycle(tf_semiparam_field, inputs, outputs, validation_data, batch_size, learning_rate_param, learning_rate_non_param, n_epochs_param, n_epochs_non_param, param_optim=None, non_param_optim=None, param_loss=None, non_param_loss=None, param_metrics=None, non_param_metrics=None, param_callback=None, non_param_callback=None, general_callback=None, first_run=False, cycle_def='complete', use_sample_weights=False, verbose=1)
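Perform one block coordinate descent (BCD) iteration on the model.
Each part of the model is optimised with its own learning rate and number of epochs. Example settings are learning_rate_param = 1e-2 and n_epochs_param = 20 for the parametric part, and learning_rate_non_param = 1.0 and n_epochs_non_param = 100 for the non-parametric part. These arguments have no defaults and must always be passed.
- Parameters: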
tf_semiparam_field (tf.keras.Model) – The model to be trained.
inputs (Tensor or list of tensors) – Inputs used for Model.fit()
outputs (Tensor) – Outputs used for Model.fit()
validation_data (Tuple) – Validation data used for Model.fit(): a tuple of (input, output) validation data.
batch_size (int) – Batch size for the training.
learning_rate_param (float) – Learning rate for the parametric part
learning_rate_non_param (float) – Learning rate for the non-parametric part
n_epochs_param (int) – Number of epochs for the parametric part
n_epochs_non_param (int) – Number of epochs for the non-parametric part
param_optim (TensorFlow optimizer) – Optimizer for the parametric part. Optional, default is the Adam optimizer.
non_param_optim (TensorFlow optimizer) – Optimizer for the non-parametric part. Optional, default is the Adam optimizer.
param_loss (TensorFlow loss) – Loss function for the parametric part. Optional, default is the MeanSquaredError() loss.
non_param_loss (TensorFlow loss) – Loss function for the non-parametric part. Optional, default is the MeanSquaredError() loss.
param_metrics (TensorFlow metrics) – Metrics for the parametric part. Optional, default is the MeanSquaredError() metric.
non_param_metrics (TensorFlow metrics) – Metrics for the non-parametric part. Optional, default is the MeanSquaredError() metric.
param_callback (TensorFlow callback) – Callback for the parametric part only. Optional, default is no callback.
non_param_callback (TensorFlow callback) – Callback for the non-parametric part only. Optional, default is no callback.
general_callback (TensorFlow callback) – Callback shared by both the parametric and non-parametric parts. Optional, default is no callback.
first_run (bool) – If True, this is the first iteration of the model training, and the non-parametric part is not considered in the first parametric training.
cycle_def (string) – Training cycle definition: one of 'parametric', 'non-parametric' or 'complete'. Default is 'complete'.
use_sample_weights (bool) – If True, sample weights are used for the training. The sample weights are computed as the inverse of the estimated noise variance.
verbose (int) – Verbosity mode used for the training procedure. If a log of the training is being saved, verbose=2 is recommended.
- Returns:
tf_semiparam_field (tf.keras.Model) – Trained TensorFlow model.
hist_param (TensorFlow History object) – History of the parametric training.
hist_non_param (TensorFlow History object) – History of the non-parametric training.
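BCD alternates between blocks of parameters, optimising one block while the others stay fixed. The toy below shows the idea on a two-parameter least-squares fit: the slope a plays the role of the parametric block and the offset b that of the non-parametric block. It is a sketch of the technique only, with a hypothetical function name; here each block is minimised exactly in closed form, whereas general_train_cycle runs a gradient-based optimiser for n_epochs_param or n_epochs_non_param epochs on each part.

```python
from typing import List, Tuple


def bcd_least_squares(x: List[float], y: List[float], n_iters: int = 200) -> Tuple[float, float]:
    """Toy block coordinate descent on f(a, b) = sum_i (y_i - a*x_i - b)**2.

    Each iteration minimises f exactly over one block while the other
    block is held fixed: first a, then b.
    """
    a, b = 0.0, 0.0
    sxx = sum(xi * xi for xi in x)
    n = len(x)
    for _ in range(n_iters):
        # Block 1 (analogous to the parametric step): optimise a with b frozen.
        a = sum(xi * (yi - b) for xi, yi in zip(x, y)) / sxx
        # Block 2 (analogous to the non-parametric step): optimise b with a frozen.
        b = sum(yi - a * xi for xi, yi in zip(x, y)) / n
    return a, b


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [2.0 * xi + 1.0 for xi in xs]  # noiseless line with a = 2, b = 1
a_hat, b_hat = bcd_least_squares(xs, ys)
print(round(a_hat, 6), round(b_hat, 6))
```

Because an exact block update can never increase the objective, the alternation converges here to the ordinary least-squares solution (a = 2, b = 1); with these inputs the error shrinks by a factor of 2/3 per iteration.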
- wf_psf.training.train_utils.param_train_cycle(tf_semiparam_field, inputs, outputs, validation_data, batch_size, learning_rate, n_epochs, param_optim=None, param_loss=None, param_metrics=None, param_callback=None, general_callback=None, use_sample_weights=False, verbose=1)
Training cycle for the parametric model.
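Like general_train_cycle, this function accepts use_sample_weights, where each sample is weighted by the inverse of its estimated noise variance. A minimal sketch of that weighting, assuming the per-sample variance estimates are already available (how wf_psf estimates them is not shown in this section); the function name and the eps floor are hypothetical additions for illustration. Keras' Model.fit accepts per-sample weights through its sample_weight argument.

```python
from typing import List


def inverse_variance_weights(noise_variances: List[float], eps: float = 1e-12) -> List[float]:
    """Return w_i = 1 / var_i for each sample's estimated noise variance.

    eps is a hypothetical safeguard (not from wf_psf) that avoids division
    by zero when a variance estimate vanishes.
    """
    return [1.0 / max(v, eps) for v in noise_variances]


# Noisier samples (larger variance) receive proportionally smaller weights.
weights = inverse_variance_weights([0.25, 1.0, 4.0])
print(weights)
```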