wf_psf.training.train_utils

Training utilities for the PSF model.

This module provides helper functions, loss and metric classes, and callbacks for training the PSF model, covering the training cycles, sample-weight computation, L1 rate scheduling, and masked mean squared error computations.

Authors: Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>, Jennifer Pollack <jennifer.pollack@cea.fr>

Functions

calculate_sample_weights(outputs, ...[, ...])

Calculate sample weights based on image noise standard deviation.

configure_optimizer_and_loss(learning_rate)

Configure and return the optimizer, loss function, and metrics for model training.

general_train_cycle(psf_model, inputs, ...)

Perform a Bi-Cycle Descent (BCD) training iteration on a semi-parametric model.

get_callbacks(callback1, callback2)

Combine two callback lists into one.

l1_schedule_rule(epoch_n, l1_rate)

Schedule the L1 rate based on the epoch number.

masked_mse(y_true, y_pred, mask[, sample_weight])

Compute the mean squared error over the masked regions.

train_cycle_part(psf_model, inputs, outputs, ...)

Train either the parametric or non-parametric part of the PSF model using the specified parameters.

Classes

L1ParamScheduler(l1_schedule_rule)

L1 rate scheduler that adjusts the L1 rate during training according to a specified schedule.

MaskedMeanSquaredError([name])

Computes the masked mean squared error (MSE) loss between predictions and targets.

MaskedMeanSquaredErrorMetric(*args, **kwargs)

A custom metric class for computing the masked mean squared error (MSE).

class wf_psf.training.train_utils.L1ParamScheduler(l1_schedule_rule)

Bases: Callback

L1 rate scheduler that adjusts the L1 rate during training according to a specified schedule.

This callback modifies the L1 regularization rate at each epoch based on the given scheduling function. The function takes the epoch index and the current L1 rate as inputs, and it outputs the updated L1 rate.

Parameters:

l1_schedule_rule (function) –

A function that defines how to update the L1 rate. The function should take two arguments:

  • epoch (int): The current epoch index, starting from 0.

  • current_l1_rate (float): The L1 rate at the current epoch.

The function should return a new L1 rate (float) to be applied at the next epoch.
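To illustrate this contract, the sketch below defines a hypothetical rule, exponential_l1_decay, which keeps the initial rate at epoch 0 and then multiplies the current rate by a constant factor each epoch. The factor 0.9 is an arbitrary assumption; any function with the same two-argument signature that returns a float could be passed as l1_schedule_rule.

```python
def exponential_l1_decay(epoch: int, current_l1_rate: float) -> float:
    """Hypothetical schedule rule: keep the initial rate at epoch 0,
    then multiply the current rate by a constant factor each epoch."""
    decay_factor = 0.9  # assumed value, chosen only for illustration
    if epoch == 0:
        return current_l1_rate
    return current_l1_rate * decay_factor


# Apply the rule epoch by epoch, feeding each output back in as the next
# input, the way a scheduler callback chains the returned rates.
l1_rate = 1e-3
history = []
for epoch in range(4):
    l1_rate = exponential_l1_decay(epoch, l1_rate)
    history.append(l1_rate)
```

Because the rule receives the current rate rather than the initial one, it can be stateless: all the state lives in the rate that is passed back in.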

Methods

on_batch_begin(batch[, logs])

A backwards compatibility alias for on_train_batch_begin.

on_batch_end(batch[, logs])

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch[, logs])

Execute callback function at the beginning of each epoch to adjust the L1 rate.

on_epoch_end(epoch[, logs])

Called at the end of an epoch.

on_predict_batch_begin(batch[, logs])

Called at the beginning of a batch in predict methods.

on_predict_batch_end(batch[, logs])

Called at the end of a batch in predict methods.

on_predict_begin([logs])

Called at the beginning of prediction.

on_predict_end([logs])

Called at the end of prediction.

on_test_batch_begin(batch[, logs])

Called at the beginning of a batch in evaluate methods.

on_test_batch_end(batch[, logs])

Called at the end of a batch in evaluate methods.

on_test_begin([logs])

Called at the beginning of evaluation or validation.

on_test_end([logs])

Called at the end of evaluation or validation.

on_train_batch_begin(batch[, logs])

Called at the beginning of a training batch in fit methods.

on_train_batch_end(batch[, logs])

Called at the end of a training batch in fit methods.

on_train_begin([logs])

Called at the beginning of training.

on_train_end([logs])

Called at the end of training.

set_model

set_params

on_epoch_begin(epoch, logs=None)

Execute callback function at the beginning of each epoch to adjust the L1 rate.

This method reads the current L1 rate from the model’s optimizer, computes the new rate with the l1_schedule_rule function, and writes it back to the model’s optimizer.

Parameters:
  • epoch (int) – The current epoch index, starting from 0.

  • logs (dict, optional) – A dictionary containing logs for the current epoch (default is None).
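The read, reschedule, and write-back sequence can be sketched in plain Python without TensorFlow. The _FakeOptimizer and _FakeModel stand-ins, their l1_rate attribute, the L1SchedulerSketch class, and the halve_every_epoch rule are all hypothetical; the attribute through which the real callback accesses the rate on the Keras optimizer is not documented here.

```python
class _FakeOptimizer:
    """Hypothetical stand-in for an optimizer that stores an L1 rate."""

    def __init__(self, l1_rate: float):
        self.l1_rate = l1_rate  # attribute name is an assumption


class _FakeModel:
    """Hypothetical stand-in for a model exposing an optimizer."""

    def __init__(self, optimizer: _FakeOptimizer):
        self.optimizer = optimizer


class L1SchedulerSketch:
    """Sketch of an L1 scheduler: read, reschedule, write back each epoch."""

    def __init__(self, l1_schedule_rule):
        self.l1_schedule_rule = l1_schedule_rule
        self.model = None

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        current_rate = float(self.model.optimizer.l1_rate)  # read the current rate
        new_rate = self.l1_schedule_rule(epoch, current_rate)  # compute the scheduled rate
        self.model.optimizer.l1_rate = new_rate  # write it back


def halve_every_epoch(epoch, rate):
    # Hypothetical rule used only to exercise the sketch.
    return rate if epoch == 0 else rate / 2


scheduler = L1SchedulerSketch(halve_every_epoch)
scheduler.set_model(_FakeModel(_FakeOptimizer(l1_rate=1.0)))
for epoch in range(3):
    scheduler.on_epoch_begin(epoch)
```

In Keras, set_model is called automatically when the callback is passed to Model.fit through its callbacks list, so only the scheduling logic needs to be written by hand.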

class wf_psf.training.train_utils.MaskedMeanSquaredError(name: str = 'masked_mean_squared_error', **kwargs)

Bases: Loss

Computes the masked mean squared error (MSE) loss between predictions and targets.

This loss function assumes that y_true has two components in the last axis:

  • y_true[..., 0]: the target values.

  • y_true[..., 1]: a mask in [0, 1], where:

    ◦ 1 means the pixel is included in the loss.

    ◦ 0 means the pixel is ignored.

    ◦ Values in (0, 1) are treated as weights for partial contribution.
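Because the target and the mask travel together in y_true, they have to be stacked along a new last axis before training. The NumPy sketch below, with arbitrary array sizes, builds such an array and recovers both components by slicing. It only illustrates the data layout and does not reproduce the loss computation.

```python
import numpy as np

batch, height, width = 4, 8, 8  # arbitrary sizes for illustration
rng = np.random.default_rng(0)

targets = rng.random((batch, height, width)).astype(np.float32)
# Binary mask with the same spatial shape as the targets; fractional values are also allowed.
mask = (rng.random((batch, height, width)) > 0.2).astype(np.float32)

# Pack target and mask along a new last axis -> shape (batch, height, width, 2).
y_true = np.stack([targets, mask], axis=-1)

# Unpack them as the loss expects: channel 0 is the target, channel 1 the mask.
y_target = y_true[..., 0]
y_mask = y_true[..., 1]
```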

Methods

__call__(y_true, y_pred[, sample_weight])

Invoke the loss computation with support for different shapes of inputs.

call(y_true, y_pred[, sample_weight])

Compute the masked mean squared error loss.

from_config(config)

Instantiates a Loss from its config (output of get_config()).

get_config()

Returns the config dictionary for a Loss instance.

call(y_true: Tensor, y_pred: Tensor, sample_weight: Tensor | None = None) → Tensor

Compute the masked mean squared error loss.

Parameters:
  • y_true (tf.Tensor) – Tensor of shape (batch, height, width, 2) containing the ground truth values and the mask.

  • y_pred (tf.Tensor) – Tensor of shape (batch, height, width) containing the predicted values.

  • sample_weight (tf.Tensor, optional) – Optional sample weights of shape (batch,), broadcast over the spatial dimensions.

Returns:

Scalar tensor representing the mean squared error over masked pixels.

Return type:

tf.Tensor

class wf_psf.training.train_utils.MaskedMeanSquaredErrorMetric(*args, **kwargs)

Bases: Metric

A custom metric class for computing the masked mean squared error (MSE).

This metric computes the mean squared error over the masked regions of the true values and predictions. A mask is applied such that a mask value of 1 excludes the pixel and 0 includes the pixel in the error computation. The metric is updated after every batch and returns the average masked MSE after processing the dataset.

Attributes:
activity_regularizer

Optional regularizer function for the output of this layer.

compute_dtype

The dtype of the layer’s computations.

dtype

The dtype of the layer weights.

dtype_policy

The dtype policy associated with this layer.

dynamic

Whether the layer is dynamic (eager-only); set in the constructor.

inbound_nodes

Return Functional API nodes upstream of this layer.

input

Retrieves the input tensor(s) of a layer.

input_mask

Retrieves the input mask tensor(s) of a layer.

input_shape

Retrieves the input shape(s) of a layer.

input_spec

InputSpec instance(s) describing the input format for this layer.

losses

List of losses added using the add_loss() API.

metrics

List of metrics added using the add_metric() API.

name

Name of the layer (string), set in the constructor.

name_scope

Returns a tf.name_scope instance for this class.

non_trainable_variables

Sequence of non-trainable variables owned by this module and its submodules.

non_trainable_weights

List of all non-trainable weights tracked by this layer.

outbound_nodes

Return Functional API nodes downstream of this layer.

output

Retrieves the output tensor(s) of a layer.

output_mask

Retrieves the output mask tensor(s) of a layer.

output_shape

Retrieves the output shape(s) of a layer.

stateful
submodules

Sequence of all sub-modules.

supports_masking

Whether this layer supports computing a mask using compute_mask.

trainable
trainable_variables

Sequence of trainable variables owned by this module and its submodules.

trainable_weights

List of all trainable weights tracked by this layer.

updates
variable_dtype

Alias of Layer.dtype, the dtype of the weights.

variables

Returns the list of all layer variables/weights.

weights

Returns the list of all layer variables/weights.

Methods

__call__(*args, **kwargs)

Accumulates statistics and then computes metric result value.

add_loss(losses, **kwargs)

Add loss tensor(s), potentially dependent on layer inputs.

add_metric(value[, name])

Adds metric tensor to the layer.

add_update(updates)

Add update op(s), potentially dependent on layer inputs.

add_variable(*args, **kwargs)

Deprecated, do NOT use! Alias for add_weight.

add_weight(name[, shape, aggregation, ...])

Adds state variable.

build(input_shape)

Creates the variables of the layer (optional, for subclass implementers).

call(inputs, *args, **kwargs)

This is where the layer's logic lives.

compute_mask(inputs[, mask])

Computes an output mask tensor.

compute_output_shape(input_shape)

Computes the output shape of the layer.

compute_output_signature(input_signature)

Compute the output tensor signature of the layer based on the inputs.

count_params()

Count the total number of scalars composing the weights.

finalize_state()

Finalizes the layers state after updating layer weights.

from_config(config)

Creates a layer from its config.

get_config()

Returns the serializable config of the metric.

get_input_at(node_index)

Retrieves the input tensor(s) of a layer at a given node.

get_input_mask_at(node_index)

Retrieves the input mask tensor(s) of a layer at a given node.

get_input_shape_at(node_index)

Retrieves the input shape(s) of a layer at a given node.

get_output_at(node_index)

Retrieves the output tensor(s) of a layer at a given node.

get_output_mask_at(node_index)

Retrieves the output mask tensor(s) of a layer at a given node.

get_output_shape_at(node_index)

Retrieves the output shape(s) of a layer at a given node.

get_weights()

Returns the current weights of the layer, as NumPy arrays.

merge_state(metrics)

Merges the state from one or more metrics.

reset_state()

Reset the state of the metric, clearing the accumulated total loss and batch count.

result()

Compute and return the current masked MSE value, averaged over all batches.

set_weights(weights)

Sets the weights of the layer, from NumPy arrays.

update_state(y_true, y_pred[, sample_weight])

Update the state of the metric with the true values, predictions, and optional sample weights.

with_name_scope(method)

Decorator to automatically enter the module name scope.

reset_states

reset_state() → None

Reset the state of the metric, clearing the accumulated total loss and batch count.

This method is typically called at the start of each new epoch or evaluation run.

Return type:

None

result() → Tensor

Compute and return the current masked MSE value, averaged over all batches.

Returns:

The current masked MSE value (average loss per batch).

Return type:

tf.Tensor

update_state(y_true: Tensor, y_pred: Tensor, sample_weight: Tensor | None = None) → None

Update the state of the metric with the true values, predictions, and optional sample weights.

This method calculates the masked MSE loss for the given batch and accumulates the total loss and batch count.

Parameters:
  • y_true (tf.Tensor) – True values with shape (batch_size, height, width, channels), where the last dimension contains both the target values and the mask.

  • y_pred (tf.Tensor) – Predicted values with shape (batch_size, height, width, channels).

  • sample_weight (tf.Tensor, optional) – Sample weights for each instance in the batch, with shape (batch_size,). If not provided, all instances are treated equally.

Return type:

None
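The reset, update, and result cycle described above follows an accumulate-then-average pattern. The sketch below reproduces that pattern in plain Python with NumPy. RunningBatchMean and plain_mse are hypothetical names; plain_mse is a placeholder for the per-batch masked MSE, and the real class stores its accumulators as TensorFlow variables rather than Python floats.

```python
import numpy as np


class RunningBatchMean:
    """Sketch of a metric that averages a per-batch loss over all batches seen."""

    def __init__(self, masked_loss_fn):
        self.masked_loss_fn = masked_loss_fn  # hypothetical per-batch loss
        self.reset_state()

    def reset_state(self):
        # Clear the accumulated total loss and the batch count.
        self.total_loss = 0.0
        self.batch_count = 0

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Compute the loss for this batch and accumulate it.
        self.total_loss += float(self.masked_loss_fn(y_true, y_pred, sample_weight))
        self.batch_count += 1

    def result(self):
        # Average loss per batch; return 0.0 before any update to avoid dividing by zero.
        if self.batch_count == 0:
            return 0.0
        return self.total_loss / self.batch_count


def plain_mse(y_true, y_pred, sample_weight=None):
    # Stand-in loss that ignores masking and sample weights, for illustration only.
    return np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2)


metric = RunningBatchMean(plain_mse)
metric.update_state(np.zeros(4), np.ones(4))      # batch loss 1.0
metric.update_state(np.zeros(4), 3 * np.ones(4))  # batch loss 9.0
```

Averaging per batch rather than per sample means that a smaller final batch counts as much as a full one in the reported value.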

wf_psf.training.train_utils.calculate_sample_weights(outputs: ndarray, use_sample_weights: bool, loss: str | Callable | None, apply_sigmoid: bool = False, sigmoid_max_val: float = 5.0, sigmoid_power_k: float = 1.0) → ndarray | None

Calculate sample weights based on image noise standard deviation.

The function computes sample weights by estimating the noise standard deviation for each image, calculating the inverse variance, and then normalizing the weights by dividing by the median.

Parameters:
  • outputs (np.ndarray) – A 3D array of images with shape (batch_size, height, width).

  • use_sample_weights (bool) – Flag indicating whether to compute sample weights. If True, sample weights will be computed based on the image noise.

  • loss (str, callable, optional) – The loss function used for training. If the loss name is “masked_mean_squared_error”, the function will calculate the noise standard deviation for masked images.

  • apply_sigmoid (bool, optional) – Flag indicating whether to apply a generalized sigmoid function to the sample weights. Default is False.

  • sigmoid_max_val (float, optional) – The maximum value for the sigmoid function. Default is 5.0.

  • sigmoid_power_k (float, optional) – The power parameter for the sigmoid function. Default is 1.0. This parameter controls the steepness of the sigmoid curve. Higher values make the curve steeper.

Returns:

An array of sample weights, or None if use_sample_weights is False.

Return type:

np.ndarray or None
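The core weighting steps (per-image noise estimate, inverse variance, median normalization) can be sketched in NumPy. The way the library estimates the noise standard deviation is not specified here, so this sketch assumes, hypothetically, that it is the standard deviation of a frame of border pixels around each image, where the PSF signal should be negligible. The function name sketch_sample_weights and its border parameter are illustrative, and the masked-loss and sigmoid options are omitted.

```python
import numpy as np


def sketch_sample_weights(outputs: np.ndarray, border: int = 2) -> np.ndarray:
    """Inverse-variance sample weights normalized by their median.

    The border-based noise estimate is an assumption for illustration,
    not the library's estimator.
    """
    n_images, height, width = outputs.shape
    # Boolean mask selecting a frame of `border` pixels around each image.
    frame = np.ones((height, width), dtype=bool)
    frame[border:height - border, border:width - border] = False

    # Per-image noise standard deviation estimated from the frame pixels.
    noise_std = np.array([outputs[i][frame].std() for i in range(n_images)])

    # Inverse variance: noisier images receive smaller weights.
    weights = 1.0 / noise_std**2

    # Normalize so that the median weight equals 1.
    return weights / np.median(weights)


# Three pure-noise images with increasing noise levels.
rng = np.random.default_rng(0)
sigmas = np.array([0.5, 1.0, 2.0])
images = rng.normal(0.0, 1.0, size=(3, 32, 32)) * sigmas[:, None, None]
w = sketch_sample_weights(images)
```

Dividing by the median rather than the mean keeps the typical weight at 1 even when a few very clean images produce extremely large inverse variances.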

wf_psf.training.train_utils.configure_optimizer_and_loss(learning_rate: float, optimizer: Callable | None = None, loss: Callable | None = None, metrics: list[Callable] | None = None) → tuple[Callable, Callable, list[Callable]]

Configure and return the optimizer, loss function, and metrics for model training.

This function configures the optimizer, loss function, and metrics for either the parametric or non-parametric model components. If no optimizer, loss, or metrics are provided, default values are used.

Parameters:
  • learning_rate (float) – The learning rate to be used by the optimizer.

  • optimizer (callable, optional) – A function or object used to initialize the optimizer (e.g., tf.keras.optimizers.Adam). If None, the default Adam optimizer with the specified learning rate is used.

  • loss (callable, optional) – The loss function to be used during training (e.g., tf.keras.losses.MeanSquaredError). If None, the default Mean Squared Error loss is used.

  • metrics (list of callable, optional) – A list of metric functions to evaluate during training (e.g., tf.keras.metrics.MeanSquaredError). If None, the default metric MeanSquaredError is used.

Returns:

  • optimizer (callable) – The optimizer function or object configured for training.

  • loss (callable) – The loss function configured for training.

  • metrics (list of callable) – The list of metrics to be used for evaluating the model.

wf_psf.training.train_utils.general_train_cycle(psf_model, inputs, outputs, validation_data, batch_size, learning_rate_param, learning_rate_non_param, n_epochs_param, n_epochs_non_param, param_optim=None, non_param_optim=None, param_loss=None, non_param_loss=None, param_metrics=None, non_param_metrics=None, param_callback=None, non_param_callback=None, general_callback=None, first_run=False, cycle_def='complete', use_sample_weights=False, apply_sigmoid=True, sigmoid_max_val=5.0, sigmoid_power_k=1.0, verbose=1)

Perform a Bi-Cycle Descent (BCD) training iteration on a semi-parametric model.

The function alternates between optimizing the parametric and/or non-parametric parts of the model across specified training cycles. Each part of the model can be trained individually or together depending on the cycle_def parameter.

For the parametric part:

  • Default learning rate: learning_rate_param = 1e-2

  • Default epochs: n_epochs_param = 20

For the non-parametric part:

  • Default learning rate: learning_rate_non_param = 1.0

  • Default epochs: n_epochs_non_param = 100

Parameters:
  • psf_model (tf.keras.Model) – A TensorFlow model representing the PSF (Point Spread Function), which may consist of both parametric and non-parametric components, or an individual component. These components are partitioned for training, with each part addressing different aspects of the PSF.

  • inputs (Tensor or list of tensors) – Input data for training (Model.fit()).

  • outputs (Tensor) – Output data for training (Model.fit()).

  • validation_data (Tuple) – Validation data as an (input_data, output_data) tuple, used for model evaluation during training.

  • batch_size (int) – The batch size for the training.

  • learning_rate_param (float) – Learning rate for the parametric part of the PSF model.

  • learning_rate_non_param (float) – Learning rate for the non-parametric part of the PSF model.

  • n_epochs_param (int) – Number of epochs to train the parametric part.

  • n_epochs_non_param (int) – Number of epochs to train the non-parametric part.

  • param_optim (tf.keras.optimizers.Optimizer, optional) – Optimizer for the parametric part. Defaults to Adam if not provided.

  • non_param_optim (tf.keras.optimizers.Optimizer, optional) – Optimizer for the non-parametric part. Defaults to Adam if not provided.

  • param_loss (tf.keras.losses.Loss, optional) – Loss function for the parametric part. Defaults to MeanSquaredError().

  • non_param_loss (tf.keras.losses.Loss, optional) – Loss function for the non-parametric part. Defaults to MeanSquaredError().

  • param_metrics (list of tf.keras.metrics.Metric, optional) – List of metrics for the parametric part. Defaults to MeanSquaredError().

  • non_param_metrics (list of tf.keras.metrics.Metric, optional) – List of metrics for the non-parametric part. Defaults to MeanSquaredError().

  • param_callback (list of tf.keras.callbacks.Callback, optional) – Callback for the parametric part only. Defaults to no callback.

  • non_param_callback (list of tf.keras.callbacks.Callback, optional) – Callback for the non-parametric part only. Defaults to no callback.

  • general_callback (list of tf.keras.callbacks.Callback, optional) – Callback shared between both the parametric and non-parametric parts. Defaults to no callback.

  • first_run (bool, optional) – If True, the call is treated as the first training iteration, and the non-parametric part is not considered during parametric training. Default is False.

  • cycle_def (str, optional) – Defines the training cycle: parametric, non-parametric, complete, only-parametric, or only-non-parametric. The complete cycle trains both parts; each of the other values trains only the part it names. Default is complete.

  • use_sample_weights (bool, optional) – If True, sample weights are used in training. Sample weights are computed based on estimated noise variance. Default is False.

  • apply_sigmoid (bool, optional) – If True, a generalized sigmoid function is applied to the sample weights. Default is True.

  • sigmoid_max_val (float, optional) – The maximum value for the sigmoid function. Default is 5.0.

  • sigmoid_power_k (float, optional) – The power parameter for the sigmoid function. Default is 1.0. This parameter controls the steepness of the sigmoid curve.

  • verbose (int, optional) – Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. Default is 1.

Returns:

  • psf_model (tf.keras.Model) – The trained PSF model.

  • hist_param (tf.keras.callbacks.History) – History object for the parametric training.

  • hist_non_param (tf.keras.callbacks.History) – History object for the non-parametric training.
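How cycle_def selects the parts to optimize can be summarized as a small lookup table. The helper parts_to_train below is hypothetical and not part of wf_psf. It encodes only what is stated above (the complete cycle trains both parts, every other value trains the part it names) and additionally assumes that a complete cycle trains the parametric part before the non-parametric one. It does not model any other difference, for example between parametric and only-parametric.

```python
_PARTS_BY_CYCLE_DEF = {
    "complete": ("parametric", "non-parametric"),  # assumed order
    "parametric": ("parametric",),
    "only-parametric": ("parametric",),
    "non-parametric": ("non-parametric",),
    "only-non-parametric": ("non-parametric",),
}


def parts_to_train(cycle_def: str = "complete") -> tuple:
    """Return the model parts trained for a given cycle_def (hypothetical helper)."""
    try:
        return _PARTS_BY_CYCLE_DEF[cycle_def]
    except KeyError:
        raise ValueError(
            f"Unknown cycle_def {cycle_def!r}; expected one of {sorted(_PARTS_BY_CYCLE_DEF)}"
        ) from None
```

Failing loudly on an unrecognized value guards against a typo in cycle_def silently skipping a whole training phase.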

wf_psf.training.train_utils.get_callbacks(callback1, callback2)

Combine two callback lists into one.

If both are None, returns None. If one is None, returns the other. Otherwise, combines both lists.

Parameters:
  • callback1 (list of tf.keras.callbacks.Callback or None) – The first callback list (e.g., parametric or non-parametric).

  • callback2 (list of tf.keras.callbacks.Callback or None) – The second callback list (e.g., general callback).

Returns:

The combined list of callbacks or None.

Return type:

list of tf.keras.callbacks.Callback or None
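The combination rule above fits in a few lines. The following sketch, under the hypothetical name combine_callbacks, implements the described behaviour without TensorFlow so that any list elements work. Placing callback1 first and building a new list instead of modifying either input are design choices of the sketch.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def combine_callbacks(
    callback1: Optional[List[T]], callback2: Optional[List[T]]
) -> Optional[List[T]]:
    """Combine two optional callback lists as described for get_callbacks."""
    if callback1 is None and callback2 is None:
        return None
    if callback1 is None:
        return callback2
    if callback2 is None:
        return callback1
    # Concatenate into a new list so neither input list is modified in place.
    return callback1 + callback2
```

Returning None, rather than an empty list, when both inputs are None lets the result be passed straight to a callbacks argument that defaults to None.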

wf_psf.training.train_utils.l1_schedule_rule(epoch_n: int, l1_rate: float) → float

Schedule the L1 rate based on the epoch number.

If the current epoch is a nonzero multiple of 10 (10, 20, 30, ...), the L1 rate is halved; otherwise, it is returned unchanged. Epoch 0 is therefore never halved.

Parameters:
  • epoch_n (int) – The current epoch number, where the epoch index starts from 0.

  • l1_rate (float) – The current L1 regularization rate.

Returns:

The updated L1 rate for the given epoch.

Return type:

float
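The rule can be written directly in code. The sketch below, named l1_schedule_rule_sketch to keep it distinct from the library function, implements the documented behaviour and chains it over epochs 0 to 30 from an arbitrary initial rate of 1.0, the way a scheduler feeds each returned rate back in.

```python
def l1_schedule_rule_sketch(epoch_n: int, l1_rate: float) -> float:
    """Halve the L1 rate at epochs 10, 20, 30, ...; leave it unchanged otherwise."""
    if epoch_n != 0 and epoch_n % 10 == 0:
        return l1_rate / 2
    return l1_rate


# Chain the rule over epochs 0..30 starting from an arbitrary initial rate.
rate = 1.0
rates = []
for epoch in range(31):
    rate = l1_schedule_rule_sketch(epoch, rate)
    rates.append(rate)
# rates: 1.0 for epochs 0-9, 0.5 for 10-19, 0.25 for 20-29, 0.125 at epoch 30
```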

wf_psf.training.train_utils.masked_mse(y_true: Tensor, y_pred: Tensor, mask: Tensor, sample_weight: Tensor | None = None) → Tensor

Compute the mean squared error over the masked regions.

Parameters:
  • y_true (tf.Tensor) – True values with shape (batch, height, width).

  • y_pred (tf.Tensor) – Predicted values with shape (batch, height, width).

  • mask (tf.Tensor) – A mask to apply, which can contain float values in [0, 1]:

    ◦ 0 means to include the pixel.

    ◦ 1 means to ignore the pixel.

    ◦ Values in (0, 1) act as weights for partial consideration.

  • sample_weight (tf.Tensor, optional) – Sample weights for each image in the batch, with shape (batch,). If provided, they are broadcast over the spatial dimensions.

Returns:

The mean squared error computed over the masked regions.

Return type:

tf.Tensor
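A NumPy sketch of this computation follows the mask convention listed for this function (0 includes a pixel, 1 ignores it), so each pixel's inclusion weight is 1 - mask. How the library normalizes the result is not stated; masked_mse_sketch assumes a weighted mean that divides the weighted squared error by the sum of the weights. Under that assumption, a uniform sample_weight leaves the result unchanged and only the relative weights between images matter.

```python
from typing import Optional

import numpy as np


def masked_mse_sketch(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    mask: np.ndarray,
    sample_weight: Optional[np.ndarray] = None,
) -> float:
    """Weighted-mean squared error over unmasked pixels (illustrative only)."""
    # Inclusion weight per pixel: mask 0 -> weight 1, mask 1 -> weight 0.
    weights = 1.0 - mask
    if sample_weight is not None:
        # Broadcast per-image weights of shape (batch,) over (height, width).
        weights = weights * sample_weight[:, None, None]
    squared_error = (y_true - y_pred) ** 2
    total_weight = weights.sum()
    if total_weight == 0:
        return 0.0  # every pixel masked out: nothing to average
    return float((weights * squared_error).sum() / total_weight)


y_true = np.zeros((1, 2, 2))
y_pred = np.array([[[1.0, 2.0], [3.0, 4.0]]])
mask = np.array([[[0.0, 0.0], [1.0, 1.0]]])  # bottom row ignored
mse = masked_mse_sketch(y_true, y_pred, mask)  # (1 + 4) / 2 = 2.5
```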

wf_psf.training.train_utils.train_cycle_part(psf_model: Model, inputs: Tensor, outputs: Tensor, batch_size: int, epochs: int, optimizer: Optimizer, loss: Callable, metrics: list[Callable], validation_data: tuple[Tensor, Tensor] | None = None, callbacks: list[Callable] | None = None, sample_weight: Tensor | None = None, verbose: int = 1, cycle_part: str = 'parametric') → Model

Train either the parametric or non-parametric part of the PSF model using the specified parameters. This function trains a single component of the model (either parametric or non-parametric) based on the provided configuration.

Parameters:
  • psf_model (tf.keras.Model) – A TensorFlow model representing the PSF (Point Spread Function), which consists of either a parametric or a non-parametric component.

  • inputs (tf.Tensor) – Input data for training the model. Expected to be a tensor with the shape of the input batch.

  • outputs (tf.Tensor) – Target output data for training the model. Expected to match the shape of inputs.

  • batch_size (int) – The number of samples per batch during training.

  • epochs (int) – The number of epochs to train the model.

  • optimizer (tf.keras.optimizers.Optimizer) – The optimizer used for training the model (e.g., Adam, SGD).

  • loss (Callable) – The loss function used for training the model. Typically a callable like tf.keras.losses.MeanSquaredError().

  • metrics (list of Callable) – List of metrics to monitor during training. Each element should be a callable metric (e.g., accuracy, precision).

  • validation_data (tuple of (tf.Tensor, tf.Tensor), optional) – Tuple of input and output tensors to evaluate the model during training. Default is None.

  • callbacks (list of Callable, optional) – List of callbacks to apply during training, such as tf.keras.callbacks.EarlyStopping. Default is None.

  • sample_weight (tf.Tensor, optional) – Weights for the samples during training. Default is None.

  • verbose (int, optional) – Verbosity mode (0, 1, or 2). Default is 1.

  • cycle_part (str, optional) – Specifies which part of the model to train (“parametric” or “non-parametric”). Default is “parametric”.

Returns:

The trained TensorFlow model after completing the specified number of epochs.

Return type:

tf.keras.Model

Notes

This function trains the model according to the provided cycle_part. If cycle_part is set to “parametric”, the function assumes the parametric part is being trained, while “non-parametric” indicates training of the non-parametric part. The model is compiled using the compile_PSF_model function before fitting.

Examples

import tensorflow as tf
from wf_psf.training.train_utils import train_cycle_part

# model, train_inputs, train_outputs, val_inputs and val_outputs are assumed to exist.
model = train_cycle_part(
    psf_model=model,
    inputs=train_inputs,
    outputs=train_outputs,
    batch_size=32,
    epochs=10,
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
    validation_data=(val_inputs, val_outputs),
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=3)],
    sample_weight=None,
    verbose=1,
)