wf_psf.training.train_utils module

class wf_psf.training.train_utils.L1ParamScheduler(l1_schedule_rule)

Bases: Callback

L1 rate scheduler that sets the L1 rate at the start of each epoch according to a schedule.

Parameters:

l1_schedule_rule (function) – A function that takes an epoch index (integer, indexed from 0) and the current l1_rate as inputs, and returns a new l1_rate (float) as output.
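To illustrate the expected signature of l1_schedule_rule, the sketch below defines a hypothetical rule, halve_every_n_epochs, that halves the L1 rate every 10 epochs. The rule name, the period and the halving factor are assumptions chosen for the example; they are not the schedule shipped with wf_psf.

```python
def halve_every_n_epochs(epoch, l1_rate, period=10, factor=0.5):
    """Hypothetical schedule rule: scale l1_rate by `factor` every `period` epochs.

    `epoch` is 0-indexed, matching the l1_schedule_rule contract, and the
    return value is always a float.
    """
    if epoch > 0 and epoch % period == 0:
        return float(l1_rate * factor)
    return float(l1_rate)


# Simulate 25 epochs, feeding each returned rate back in as the current rate,
# as a scheduler callback would do at the start of every epoch.
rate = 1e-3
history = []
for epoch in range(25):
    rate = halve_every_n_epochs(epoch, rate)
    history.append(rate)
```

Because the rule receives the current rate rather than the initial one, a multiplicative rule like this compounds across epochs.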

Methods

on_batch_begin(batch[, logs])

A backwards compatibility alias for on_train_batch_begin.

on_batch_end(batch[, logs])

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch[, logs])

Called at the start of an epoch.

on_epoch_end(epoch[, logs])

Called at the end of an epoch.

on_predict_batch_begin(batch[, logs])

Called at the beginning of a batch in predict methods.

on_predict_batch_end(batch[, logs])

Called at the end of a batch in predict methods.

on_predict_begin([logs])

Called at the beginning of prediction.

on_predict_end([logs])

Called at the end of prediction.

on_test_batch_begin(batch[, logs])

Called at the beginning of a batch in evaluate methods.

on_test_batch_end(batch[, logs])

Called at the end of a batch in evaluate methods.

on_test_begin([logs])

Called at the beginning of evaluation or validation.

on_test_end([logs])

Called at the end of evaluation or validation.

on_train_batch_begin(batch[, logs])

Called at the beginning of a training batch in fit methods.

on_train_batch_end(batch[, logs])

Called at the end of a training batch in fit methods.

on_train_begin([logs])

Called at the beginning of training.

on_train_end([logs])

Called at the end of training.

set_model

set_params

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Overridden here to compute the new L1 rate with l1_schedule_rule and set it on the model. This function should only be called during TRAIN mode.

Parameters:
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method, but that may change in the future.
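The mechanism of L1ParamScheduler can be sketched without TensorFlow as follows. This is a framework-free mock that only mirrors the Keras Callback hooks (set_model once, then on_epoch_begin at each epoch); the ToyModel class and its l1_rate attribute are hypothetical stand-ins, since how the real model stores its L1 rate is not documented here.

```python
class ToyModel:
    """Hypothetical stand-in for a model exposing a mutable L1 rate."""

    def __init__(self, l1_rate):
        self.l1_rate = l1_rate


class SketchL1Scheduler:
    """Framework-free sketch of an L1 rate scheduler callback.

    Mirrors the Keras Callback hooks: the training loop calls set_model()
    once, then on_epoch_begin(epoch) at the start of every epoch.
    """

    def __init__(self, l1_schedule_rule):
        self.l1_schedule_rule = l1_schedule_rule
        self.model = None

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        # Read the current rate, apply the rule, and write the new rate back.
        current = float(self.model.l1_rate)
        scheduled = self.l1_schedule_rule(epoch, current)
        if not isinstance(scheduled, (float, int)):
            raise ValueError("l1_schedule_rule must return a number")
        self.model.l1_rate = float(scheduled)


model = ToyModel(l1_rate=1e-2)
# Example rule: keep the rate at epoch 0, then decay it by 10% per epoch.
scheduler = SketchL1Scheduler(lambda epoch, rate: rate * 0.9 if epoch > 0 else rate)
scheduler.set_model(model)
for epoch in range(3):  # stands in for the epochs of Model.fit()
    scheduler.on_epoch_begin(epoch)
```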

wf_psf.training.train_utils.general_train_cycle(tf_semiparam_field, inputs, outputs, validation_data, batch_size, learning_rate_param, learning_rate_non_param, n_epochs_param, n_epochs_non_param, param_optim=None, non_param_optim=None, param_loss=None, non_param_loss=None, param_metrics=None, non_param_metrics=None, param_callback=None, non_param_callback=None, general_callback=None, first_run=False, cycle_def='complete', use_sample_weights=False, verbose=1)

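Run one block coordinate descent (BCD) iteration on the model.

Depending on cycle_def, the parametric part, the non-parametric part, or both are optimised, each with its own learning rate, number of epochs, optimizer, loss, metrics and callbacks.

Typical settings are learning_rate_param = 1e-2 and n_epochs_param = 20 for the parametric part, and learning_rate_non_param = 1.0 and n_epochs_non_param = 100 for the non-parametric part.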
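To make the BCD idea concrete, here is a toy sketch unrelated to PSF modelling: a coupled quadratic loss in two parameter blocks, a (standing in for the parametric part) and b (standing in for the non-parametric part), is minimised by updating one block at a time with plain gradient steps while the other block stays frozen. The loss, learning rates and step counts are illustrative assumptions; steps_a and steps_b play the role of n_epochs_param and n_epochs_non_param.

```python
def loss(a, b):
    # Coupled toy objective; its minimum is at a = 2, b = 1 with loss 0.
    return (a + b - 3.0) ** 2 + (a - 2.0) ** 2


def grad_a(a, b):
    # Partial derivative of the loss with respect to block a.
    return 2.0 * (a + b - 3.0) + 2.0 * (a - 2.0)


def grad_b(a, b):
    # Partial derivative of the loss with respect to block b.
    return 2.0 * (a + b - 3.0)


def bcd(a, b, lr_a, lr_b, steps_a, steps_b, n_cycles):
    """Block coordinate descent with inexact (gradient-step) block updates."""
    for _ in range(n_cycles):
        # Block 1: update a only, with b frozen.
        for _ in range(steps_a):
            a -= lr_a * grad_a(a, b)
        # Block 2: update b only, with a frozen.
        for _ in range(steps_b):
            b -= lr_b * grad_b(a, b)
    return a, b


a, b = bcd(0.0, 0.0, lr_a=0.1, lr_b=0.1, steps_a=20, steps_b=100, n_cycles=30)
```

Because the blocks are coupled through the (a + b - 3) term, a single pass does not reach the minimum; repeating the cycle (here 30 times) is what drives both blocks to a = 2, b = 1.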

Parameters:
  • tf_semiparam_field (tf.keras.Model) – The model to be trained.

  • inputs (Tensor or list of tensors) – Inputs used for Model.fit()

  • outputs (Tensor) – Outputs used for Model.fit()

  • validation_data (Tuple) – Validation data used by Model.fit(), given as a tuple of (input, output) validation data.

  • batch_size (int) – Batch size for the training.

  • learning_rate_param (float) – Learning rate for the parametric part

  • learning_rate_non_param (float) – Learning rate for the non-parametric part

  • n_epochs_param (int) – Number of epochs for the parametric part

  • n_epochs_non_param (int) – Number of epochs for the non-parametric part

  • param_optim (TensorFlow optimizer) – Optimizer for the parametric part. Optional, default is the Adam optimizer.

  • non_param_optim (TensorFlow optimizer) – Optimizer for the non-parametric part. Optional, default is the Adam optimizer.

  • param_loss (TensorFlow loss) – Loss function for the parametric part. Optional, default is the MeanSquaredError() loss.

  • non_param_loss (TensorFlow loss) – Loss function for the non-parametric part. Optional, default is the MeanSquaredError() loss.

  • param_metrics (TensorFlow metrics) – Metrics for the parametric part. Optional, default is the MeanSquaredError() metric.

  • non_param_metrics (TensorFlow metrics) – Metrics for the non-parametric part. Optional, default is the MeanSquaredError() metric.

  • param_callback (TensorFlow callback) – Callback for the parametric part only. Optional, default is no callback.

  • non_param_callback (TensorFlow callback) – Callback for the non-parametric part only. Optional, default is no callback.

  • general_callback (TensorFlow callback) – Callback shared by the parametric and non-parametric parts. Optional, default is no callback.

  • first_run (bool) – If True, this is the first iteration of the model training, and the non-parametric part is not considered during the first parametric training. Default is False.

  • cycle_def (string) – Training cycle definition. It can be 'parametric', 'non-parametric' or 'complete'. Default is 'complete'.

  • use_sample_weights (bool) – If True, sample weights are used for the training. The sample weights are computed as the inverse of the estimated noise variance. Default is False.

  • verbose (int) – Verbosity mode used for the training procedure. If a log of the training is being saved, verbose=2 is recommended.
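The interplay of cycle_def and first_run described above can be summarised by a small planning helper. The function plan_cycle below is hypothetical and is not part of wf_psf; it only lists, in order, which part each phase of one BCD iteration trains and whether the non-parametric part contributes to the model during that phase, assuming the parametric phase runs before the non-parametric one.

```python
VALID_CYCLES = ("parametric", "non-parametric", "complete")


def plan_cycle(cycle_def="complete", first_run=False):
    """Hypothetical helper: list the training phases of one BCD iteration.

    Each phase is a tuple (trained_part, non_param_part_active).
    """
    if cycle_def not in VALID_CYCLES:
        raise ValueError(f"cycle_def must be one of {VALID_CYCLES}, got {cycle_def!r}")
    phases = []
    if cycle_def in ("parametric", "complete"):
        # On the first run, the non-parametric part is left out of the
        # parametric training, as documented for first_run.
        phases.append(("parametric", not first_run))
    if cycle_def in ("non-parametric", "complete"):
        phases.append(("non-parametric", True))
    return phases


print(plan_cycle("complete", first_run=True))
```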

Returns:

  • tf_semiparam_field (tf.keras.Model) – Trained TensorFlow model.

  • hist_param (TensorFlow History object) – History of the parametric training.

  • hist_non_param (TensorFlow History object) – History of the non-parametric training.
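The use_sample_weights option weights each sample by the inverse of its estimated noise variance, so that noisier samples count less in the loss. The NumPy sketch below assumes the per-sample noise standard deviations are already available (how wf_psf estimates them is not described here), and the normalisation of the weights to unit mean is an added assumption; weighted_mse is a hypothetical helper showing the effect of such weights on a per-sample mean squared error.

```python
import numpy as np


def inverse_variance_weights(noise_std, normalise=True):
    """Weights proportional to 1 / sigma^2, one per sample.

    `noise_std` holds one estimated noise standard deviation per sample.
    Normalising to unit mean keeps the loss scale close to the unweighted case.
    """
    variance = np.asarray(noise_std, dtype=float) ** 2
    weights = 1.0 / variance
    if normalise:
        weights = weights / weights.mean()
    return weights


def weighted_mse(y_true, y_pred, weights):
    """Per-sample MSE over all pixels, then a weighted average over samples."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    per_sample = np.mean((y_true - y_pred) ** 2, axis=tuple(range(1, y_true.ndim)))
    return float(np.mean(weights * per_sample))


# Two 2x2 "images": the second sample is twice as noisy and has a larger error.
w = inverse_variance_weights([1.0, 2.0])
y_true = np.zeros((2, 2, 2))
y_pred = np.stack([np.ones((2, 2)), 2.0 * np.ones((2, 2))])
print(w, weighted_mse(y_true, y_pred, w))
```

With these inputs the noisier sample has four times the squared error but a quarter of the weight, so both samples contribute equally to the weighted loss.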

wf_psf.training.train_utils.l1_schedule_rule(epoch_n, l1_rate)

wf_psf.training.train_utils.param_train_cycle(tf_semiparam_field, inputs, outputs, validation_data, batch_size, learning_rate, n_epochs, param_optim=None, param_loss=None, param_metrics=None, param_callback=None, general_callback=None, use_sample_weights=False, verbose=1)

Training cycle for the parametric model.
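param_train_cycle accepts both param_callback and general_callback, while Model.fit() takes a single callbacks list. One plausible way to combine them is sketched below; the helper name merge_callbacks and its handling of None, single callbacks and lists are assumptions for illustration, not taken from wf_psf.

```python
def merge_callbacks(*callbacks):
    """Hypothetical helper: flatten optional callbacks into one list for Model.fit().

    Each argument may be None, a single callback object, or a list/tuple of
    callbacks. Order is preserved; None entries are skipped.
    """
    merged = []
    for cb in callbacks:
        if cb is None:
            continue
        if isinstance(cb, (list, tuple)):
            merged.extend(cb)
        else:
            merged.append(cb)
    return merged
```

The result could then be passed as callbacks=merge_callbacks(general_callback, param_callback) in the call to Model.fit().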