wf_psf.training.train_utils
Training utilities for the PSF model.
This module contains helper functions and utilities related to the training process for the PSF model. These functions help with managing training cycles, callbacks, and related operations.
- Authors: Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>, Jennifer Pollack <jennifer.pollack@cea.fr>
Functions
| calculate_sample_weights | Calculate sample weights based on image noise standard deviation. |
| configure_optimizer_and_loss | Configure and return the optimizer, loss function, and metrics for model training. |
| general_train_cycle | Perform a Bi-Cycle Descent (BCD) training iteration on a semi-parametric model. |
| get_callbacks | Combine two callback lists into one. |
| l1_schedule_rule | Schedule the L1 rate based on the epoch number. |
| masked_mse | Compute the mean squared error over the masked regions. |
| train_cycle_part | Train either the parametric or non-parametric part of the PSF model using the specified parameters. |
Classes
| L1ParamScheduler | L1 rate scheduler that adjusts the L1 rate during training according to a specified schedule. |
| MaskedMeanSquaredError | Computes the masked mean squared error (MSE) loss between predictions and targets. |
| MaskedMeanSquaredErrorMetric | A custom metric class for computing the masked mean squared error (MSE). |
- class wf_psf.training.train_utils.L1ParamScheduler(*args: Any, **kwargs: Any)
Bases: Callback
L1 rate scheduler that adjusts the L1 rate during training according to a specified schedule.
This callback modifies the L1 regularization rate at the beginning of each epoch using the given scheduling function. The function takes the epoch index and the current L1 rate as inputs and returns the updated L1 rate.
- Parameters:
l1_schedule_rule (function) –
A function that defines how to update the L1 rate. It should take two arguments:
- epoch (int): the current epoch index, starting from 0.
- current_l1_rate (float): the L1 rate currently in use.
It should return a new L1 rate (float) to be applied during the epoch that is about to begin.
Methods
| __call__(*args, **kwargs) | Call self as a function. |
| on_epoch_begin(epoch[, logs]) | Execute callback function at the beginning of each epoch to adjust the L1 rate. |
- on_epoch_begin(epoch, logs=None)
Execute callback function at the beginning of each epoch to adjust the L1 rate.
This function gets the current L1 rate from the model’s optimizer, computes the new scheduled L1 rate using the l1_schedule_rule function, and sets it back to the model’s optimizer.
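The read, schedule, and write-back flow of on_epoch_begin can be illustrated without TensorFlow. The following is a minimal sketch. HypotheticalModel, L1SchedulerSketch, and decay_rule are illustrative stand-ins, not wf_psf or Keras classes. It also assumes the L1 rate is a plain attribute on the model.

```python
# Framework-free sketch of the read -> schedule -> write flow described for
# L1ParamScheduler.on_epoch_begin. HypotheticalModel and L1SchedulerSketch are
# illustrative stand-ins, not wf_psf or Keras classes; in Keras, a callback
# receives its model through set_model() rather than its constructor.


class HypotheticalModel:
    """Stand-in for a PSF model exposing an adjustable L1 rate."""

    def __init__(self, l1_rate: float) -> None:
        self.l1_rate = l1_rate


class L1SchedulerSketch:
    """Mimics a callback that updates the model's L1 rate at each epoch start."""

    def __init__(self, l1_schedule_rule, model: HypotheticalModel) -> None:
        self.l1_schedule_rule = l1_schedule_rule
        self.model = model

    def on_epoch_begin(self, epoch: int, logs=None) -> None:
        current = self.model.l1_rate                       # read the current rate
        scheduled = self.l1_schedule_rule(epoch, current)  # apply the schedule rule
        self.model.l1_rate = scheduled                     # write it back before the epoch runs


def decay_rule(epoch: int, l1_rate: float) -> float:
    """Hypothetical rule: shrink the L1 rate by 10% at every epoch."""
    return 0.9 * l1_rate


model = HypotheticalModel(l1_rate=1e-3)
scheduler = L1SchedulerSketch(decay_rule, model)
for epoch in range(5):                                     # simulate five epochs
    scheduler.on_epoch_begin(epoch)
print(model.l1_rate)
```

Because the rule is applied before each epoch starts, epoch 0 already trains with the first scheduled value.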
- class wf_psf.training.train_utils.MaskedMeanSquaredError(*args: Any, **kwargs: Any)
Bases: Loss
Computes the masked mean squared error (MSE) loss between predictions and targets.
This loss function assumes that y_true has two components in the last axis:
- y_true[..., 0]: the target values.
- y_true[..., 1]: a mask in [0, 1], where:
  - 1 means the pixel is included in the loss;
  - 0 means the pixel is ignored;
  - values in (0, 1) are treated as weights for partial contribution.
Methods
| __call__(y_true, y_pred[, sample_weight]) | Invoke the loss computation with support for different shapes of inputs. |
| call(y_true, y_pred[, sample_weight]) | Compute the masked mean squared error loss. |
- call(y_true: tensorflow.Tensor, y_pred: tensorflow.Tensor, sample_weight: tensorflow.Tensor | None = None) → tensorflow.Tensor
Compute the masked mean squared error loss.
- Parameters:
y_true (tf.Tensor) – Tensor of shape (batch, height, width, 2) containing the ground truth values and the mask.
y_pred (tf.Tensor) – Tensor of shape (batch, height, width) containing the predicted values.
sample_weight (tf.Tensor, optional) – Optional sample weights of shape (batch,). Broadcast over spatial dims.
- Returns:
Scalar tensor representing the mean squared error over masked pixels.
- Return type:
tf.Tensor
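The two-channel y_true layout can be sketched in NumPy. This sketch makes three assumptions that do not come from wf_psf. First, the mask in y_true[..., 1] is used directly as a per-pixel weight, following the convention stated for this class. Second, the loss is normalised by the total mask weight. Third, sample_weight is omitted. The helper names pack_targets and masked_mse_loss_sketch are hypothetical.

```python
import numpy as np

# NumPy sketch of the y_true layout described for MaskedMeanSquaredError.
# ASSUMPTIONS (not taken from wf_psf): the mask in y_true[..., 1] is used
# directly as a per-pixel weight (1 = include, as stated for this class),
# the loss is normalised by the total mask weight, and sample_weight is omitted.


def pack_targets(targets: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Stack (batch, H, W) targets and mask into a (batch, H, W, 2) array."""
    return np.stack([targets, mask], axis=-1)


def masked_mse_loss_sketch(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mask-weighted MSE between y_true[..., 0] and a (batch, H, W) prediction."""
    target = y_true[..., 0]
    weight = y_true[..., 1]
    total_weight = np.sum(weight)
    if total_weight == 0:
        return 0.0  # every pixel is masked out
    return float(np.sum(weight * (target - y_pred) ** 2) / total_weight)


targets = np.zeros((1, 2, 2))
preds = np.array([[[1.0, 1.0], [3.0, 0.0]]])
mask = np.array([[[1.0, 1.0], [0.0, 1.0]]])  # 0 drops the pixel with error 3
y_true = pack_targets(targets, mask)
loss_value = masked_mse_loss_sketch(y_true, preds)
print(y_true.shape, loss_value)
```

Packing the mask into y_true lets a standard (y_true, y_pred) loss signature carry per-pixel information without extra model inputs.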
- class wf_psf.training.train_utils.MaskedMeanSquaredErrorMetric(*args: Any, **kwargs: Any)
Bases: Metric
A custom metric class for computing the masked mean squared error (MSE).
This metric computes the mean squared error over the masked regions of the true values and predictions. A mask is applied such that a mask value of 1 excludes the pixel and 0 includes the pixel in the error computation. The metric is updated after every batch and returns the average masked MSE after processing the dataset.
Methods
| __call__(*args, **kwargs) | Call self as a function. |
| reset_state() | Reset the state of the metric, clearing the accumulated total loss and batch count. |
| result() | Compute and return the current masked MSE value, averaged over all batches. |
| update_state(y_true, y_pred[, sample_weight]) | Update the state of the metric with the true values, predictions, and optional sample weights. |
- reset_state() → None
Reset the state of the metric, clearing the accumulated total loss and batch count.
This method is typically called at the start of each training epoch or evaluation run.
- Return type:
None
- result() → tensorflow.Tensor
Compute and return the current masked MSE value, averaged over all batches.
- Returns:
The current masked MSE value (average loss per batch).
- Return type:
tf.Tensor
- update_state(y_true: tensorflow.Tensor, y_pred: tensorflow.Tensor, sample_weight: tensorflow.Tensor | None = None) → None
Update the state of the metric with the true values, predictions, and optional sample weights.
This method calculates the masked MSE loss for the given batch and accumulates the total loss and batch count.
- Parameters:
y_true (tf.Tensor) – True values with shape (batch_size, height, width, channels), where the last dimension contains both the target values and the mask.
y_pred (tf.Tensor) – Predicted values with shape (batch_size, height, width, channels).
sample_weight (tf.Tensor, optional) – Sample weights for each instance in the batch, with shape (batch_size,). If not provided, all instances are treated equally.
- Return type:
None
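The metric accumulates a total loss and a batch count, then reports their ratio. That pattern can be sketched in plain Python. AveragedBatchMetricSketch is hypothetical. The per-batch loss is passed in as a function, so the sketch does not depend on any particular mask convention. Sample weights are omitted.

```python
import numpy as np

# Sketch of the accumulate-then-average behaviour described for
# MaskedMeanSquaredErrorMetric. AveragedBatchMetricSketch is hypothetical; the
# per-batch loss is injected so the sketch does not depend on a mask convention,
# and sample weights are omitted.


class AveragedBatchMetricSketch:
    def __init__(self, batch_loss_fn):
        self.batch_loss_fn = batch_loss_fn
        self.reset_state()

    def reset_state(self) -> None:
        """Clear the accumulated total loss and batch count."""
        self.total_loss = 0.0
        self.batch_count = 0

    def update_state(self, y_true, y_pred) -> None:
        """Add one batch's loss to the running total."""
        self.total_loss += float(self.batch_loss_fn(y_true, y_pred))
        self.batch_count += 1

    def result(self) -> float:
        """Average loss per batch (0.0 before any batch has been seen)."""
        return self.total_loss / self.batch_count if self.batch_count else 0.0


def plain_mse(y_true, y_pred) -> float:
    return float(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))


metric = AveragedBatchMetricSketch(plain_mse)
metric.update_state([0.0, 0.0], [1.0, 1.0])  # batch MSE = 1.0
metric.update_state([0.0, 0.0], [3.0, 3.0])  # batch MSE = 9.0
value = metric.result()
print(value)
metric.reset_state()  # e.g. at the start of the next epoch
```

The result is an average of per-batch values. A smaller final batch therefore counts as much as a full one.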
- wf_psf.training.train_utils.calculate_sample_weights(outputs: ndarray, use_sample_weights: bool, loss: str | Callable | None, apply_sigmoid: bool = False, sigmoid_max_val: float = 5.0, sigmoid_power_k: float = 1.0) → ndarray | None
Calculate sample weights based on image noise standard deviation.
The function computes sample weights by estimating the noise standard deviation for each image, calculating the inverse variance, and then normalizing the weights by dividing by the median.
- Parameters:
outputs (np.ndarray) – A 3D array of shape (batch_size, height, width) representing images, where the first dimension is the batch size and the next two dimensions are the image height and width.
use_sample_weights (bool) – Flag indicating whether to compute sample weights. If True, sample weights will be computed based on the image noise.
loss (str, callable, optional) – The loss function used for training. If the loss name is “masked_mean_squared_error”, the function will calculate the noise standard deviation for masked images.
apply_sigmoid (bool, optional) – Flag indicating whether to apply a generalized sigmoid function to the sample weights. Default is False.
sigmoid_max_val (float, optional) – The maximum value for the sigmoid function. Default is 5.0.
sigmoid_power_k (float, optional) – The power parameter for the sigmoid function. Default is 1.0. This parameter controls the steepness of the sigmoid curve. Higher values make the curve steeper.
- Returns:
An array of sample weights, or None if use_sample_weights is False.
- Return type:
np.ndarray or None
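The weighting steps (noise standard deviation, inverse variance, division by the median) can be sketched in NumPy. The noise estimator here is an assumption. It takes the standard deviation of a border of width border around each stamp, which is a hypothetical choice, and wf_psf's own estimator may differ. The optional generalized sigmoid step is left out because its exact form is not given here.

```python
import numpy as np

# Sketch of inverse-variance sample weights normalised by their median.
# ASSUMPTION: the noise standard deviation is estimated from a border of width
# `border` around each stamp (a hypothetical choice); wf_psf's own noise
# estimator may differ. The optional sigmoid step is omitted.


def estimate_noise_std(images: np.ndarray, border: int = 2) -> np.ndarray:
    """Per-image std of the pixels lying within `border` pixels of the edge."""
    _, h, w = images.shape
    edge = np.ones((h, w), dtype=bool)
    edge[border:h - border, border:w - border] = False
    return np.array([img[edge].std() for img in images])


def inverse_variance_weights(images: np.ndarray) -> np.ndarray:
    sigma = estimate_noise_std(images)
    weights = 1.0 / sigma**2               # inverse variance per image
    return weights / np.median(weights)    # the median weight becomes 1


rng = np.random.default_rng(0)
noise_levels = np.array([1.0, 2.0, 4.0])
stamps = rng.normal(size=(3, 32, 32)) * noise_levels[:, None, None]
weights = inverse_variance_weights(stamps)
print(weights)  # noisier stamps receive smaller weights
```

Dividing by the median keeps typical weights near 1. As a result, enabling weighting does not change the overall scale of the loss much.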
- wf_psf.training.train_utils.configure_optimizer_and_loss(learning_rate: float, optimizer: Callable | None = None, loss: Callable | None = None, metrics: list[Callable] | None = None) → tuple[Callable, Callable, list[Callable]]
Configure and return the optimizer, loss function, and metrics for model training.
This function configures the optimizer, loss function, and metrics for either the parametric or non-parametric model components. If no optimizer, loss, or metrics are provided, default values are used.
- Parameters:
learning_rate (float) – The learning rate to be used by the optimizer.
optimizer (callable, optional) – A function or object used to initialize the optimizer (e.g., tf.keras.optimizers.Adam). If None, the default Adam optimizer with the specified learning rate is used.
loss (callable, optional) – The loss function to be used during training (e.g., tf.keras.losses.MeanSquaredError). If None, the default Mean Squared Error loss is used.
metrics (list of callable, optional) – A list of metric functions to evaluate during training (e.g., tf.keras.metrics.MeanSquaredError). If None, the default metric MeanSquaredError is used.
- Returns:
optimizer (callable) – The optimizer function or object configured for training.
loss (callable) – The loss function configured for training.
metrics (list of callable) – The list of metrics to be used for evaluating the model.
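The "use the argument if given, otherwise a default" logic can be sketched without TensorFlow. This sketch uses hypothetical placeholders: default_optimizer and default_loss stand in for Adam(learning_rate=...), MeanSquaredError(), and [MeanSquaredError()]. It also assumes that a supplied optimizer is returned as-is.

```python
from typing import Callable, List, Optional, Tuple

# Framework-free sketch of the fallback-to-defaults logic described for
# configure_optimizer_and_loss. default_optimizer and default_loss are
# hypothetical placeholders for Adam(learning_rate=...) and MeanSquaredError();
# a supplied optimizer is assumed to be returned unchanged.


def default_optimizer(learning_rate: float) -> dict:
    return {"name": "adam", "learning_rate": learning_rate}


def default_loss(y_true, y_pred) -> float:
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def configure_sketch(
    learning_rate: float,
    optimizer: Optional[object] = None,
    loss: Optional[Callable] = None,
    metrics: Optional[List[Callable]] = None,
) -> Tuple[object, Callable, List[Callable]]:
    optimizer = optimizer if optimizer is not None else default_optimizer(learning_rate)
    loss = loss if loss is not None else default_loss
    metrics = metrics if metrics is not None else [default_loss]
    return optimizer, loss, metrics


opt, loss_fn, metric_fns = configure_sketch(learning_rate=1e-2)
print(opt, loss_fn([0.0, 0.0], [1.0, 3.0]), len(metric_fns))
```

Testing against None rather than truthiness means an explicitly passed empty metrics list is kept rather than replaced by the default.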
- wf_psf.training.train_utils.general_train_cycle(psf_model, inputs, outputs, validation_data, batch_size, learning_rate_param, learning_rate_non_param, n_epochs_param, n_epochs_non_param, param_optim=None, non_param_optim=None, param_loss=None, non_param_loss=None, param_metrics=None, non_param_metrics=None, param_callback=None, non_param_callback=None, general_callback=None, first_run=False, cycle_def='complete', use_sample_weights=False, apply_sigmoid=True, sigmoid_max_val=5.0, sigmoid_power_k=1.0, verbose=1)
Perform a Bi-Cycle Descent (BCD) training iteration on a semi-parametric model.
The function alternates between optimizing the parametric and/or non-parametric parts of the model across specified training cycles. Each part of the model can be trained individually or together depending on the cycle_def parameter.
For the parametric part:
- Default learning rate: learning_rate_param = 1e-2
- Default epochs: n_epochs_param = 20
For the non-parametric part:
- Default learning rate: learning_rate_non_param = 1.0
- Default epochs: n_epochs_non_param = 100
- Parameters:
psf_model (tf.keras.Model) – A TensorFlow model representing the PSF (Point Spread Function), which may consist of both parametric and non-parametric components, or an individual component. These components are partitioned for training, with each part addressing different aspects of the PSF.
inputs (Tensor or list of tensors) – Input data for training (Model.fit()).
outputs (Tensor) – Output data for training (Model.fit()).
validation_data (tuple) – Validation data used for model evaluation during training, given as (input_data, output_data).
batch_size (int) – The batch size for the training.
learning_rate_param (float) – Learning rate for the parametric part of the PSF model.
learning_rate_non_param (float) – Learning rate for the non-parametric part of the PSF model.
n_epochs_param (int) – Number of epochs to train the parametric part.
n_epochs_non_param (int) – Number of epochs to train the non-parametric part.
param_optim (tf.keras.optimizers.Optimizer, optional) – Optimizer for the parametric part. Defaults to Adam if not provided.
non_param_optim (tf.keras.optimizers.Optimizer, optional) – Optimizer for the non-parametric part. Defaults to Adam if not provided.
param_loss (tf.keras.losses.Loss, optional) – Loss function for the parametric part. Defaults to MeanSquaredError().
non_param_loss (tf.keras.losses.Loss, optional) – Loss function for the non-parametric part. Defaults to MeanSquaredError().
param_metrics (list of tf.keras.metrics.Metric, optional) – List of metrics for the parametric part. Defaults to [MeanSquaredError()].
non_param_metrics (list of tf.keras.metrics.Metric, optional) – List of metrics for the non-parametric part. Defaults to [MeanSquaredError()].
param_callback (list of tf.keras.callbacks.Callback, optional) – Callbacks for the parametric part only. Defaults to no callbacks.
non_param_callback (list of tf.keras.callbacks.Callback, optional) – Callbacks for the non-parametric part only. Defaults to no callbacks.
general_callback (list of tf.keras.callbacks.Callback, optional) – Callbacks shared by both the parametric and non-parametric parts. Defaults to no callbacks.
first_run (bool, optional) – If True, this call is treated as the first training iteration, and the non-parametric part is not considered during the parametric training. Default is False.
cycle_def (str, optional) – Defines the training cycle: parametric, non-parametric, complete, only-parametric, or only-non-parametric. The complete cycle trains both parts, while the others train only the specified part (either the parametric or the non-parametric one). Default is complete.
use_sample_weights (bool, optional) – If True, sample weights are used in training. Sample weights are computed based on estimated noise variance. Default is False.
apply_sigmoid (bool, optional) – If True, a generalized sigmoid function is applied to the sample weights. Default is True.
sigmoid_max_val (float, optional) – The maximum value for the sigmoid function. Default is 5.0.
sigmoid_power_k (float, optional) – The power parameter for the sigmoid function. Default is 1.0. This parameter controls the steepness of the sigmoid curve.
verbose (int, optional) – Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. Default is 1.
- Returns:
psf_model (tf.keras.Model) – The trained PSF model.
hist_param (tf.keras.callbacks.History) – History object for the parametric training.
hist_non_param (tf.keras.callbacks.History) – History object for the non-parametric training.
- wf_psf.training.train_utils.get_callbacks(callback1, callback2)
Combine two callback lists into one.
If both are None, returns None. If one is None, returns the other. Otherwise, combines both lists.
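- Parameters:
callback1 (list of tf.keras.callbacks.Callback or None) – The first list of callbacks.
callback2 (list of tf.keras.callbacks.Callback or None) – The second list of callbacks.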
- Returns:
The combined list of callbacks or None.
- Return type:
list of tf.keras.callbacks.Callback or None
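The None-aware combination rule can be sketched in a few lines of plain Python. combine_callbacks is a hypothetical name. Strings stand in for tf.keras.callbacks.Callback objects.

```python
from typing import Optional

# Minimal sketch of the combination rule described for get_callbacks;
# strings stand in for tf.keras.callbacks.Callback objects.


def combine_callbacks(callback1: Optional[list], callback2: Optional[list]) -> Optional[list]:
    if callback1 is None and callback2 is None:
        return None                    # nothing to combine
    if callback1 is None:
        return callback2               # only the second list was given
    if callback2 is None:
        return callback1               # only the first list was given
    return callback1 + callback2       # new list; the inputs are not modified


print(combine_callbacks(None, None))
print(combine_callbacks(["early_stop"], None))
print(combine_callbacks(["early_stop"], ["l1_scheduler"]))
```

Returning None rather than an empty list when both inputs are None keeps the result directly usable as the callbacks argument of Model.fit, which accepts None.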
- wf_psf.training.train_utils.l1_schedule_rule(epoch_n: int, l1_rate: float) → float
Schedule the L1 rate based on the epoch number.
If the current epoch index is a nonzero multiple of 10 (epoch 10, 20, 30, and so on), the L1 rate is halved. Epoch 0, the first epoch, is excluded. Otherwise, the L1 rate is returned unchanged.
- wf_psf.training.train_utils.masked_mse(y_true: tensorflow.Tensor, y_pred: tensorflow.Tensor, mask: tensorflow.Tensor, sample_weight: tensorflow.Tensor | None = None) → tensorflow.Tensor
Compute the mean squared error over the masked regions.
- Parameters:
y_true (tf.Tensor) – True values with shape (batch, height, width).
y_pred (tf.Tensor) – Predicted values with shape (batch, height, width).
mask (tf.Tensor) – A mask to apply, which can contain float values in [0, 1]:
- 0 means to include the pixel.
- 1 means to ignore the pixel.
- Values in (0, 1) act as weights for partial consideration.
sample_weight (tf.Tensor, optional) – Sample weights for each image in the batch, with shape (batch,). If provided, they are broadcast over the spatial dimensions.
- Returns:
The mean squared error computed over the masked regions.
- Return type:
tf.Tensor
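A NumPy sketch can follow the mask convention documented for this function (0 = include, 1 = ignore). It rests on three assumptions that are not confirmed by the source. The per-pixel weight is taken to be (1 - mask). Per-image sample weights multiply that weight after broadcasting. The result is normalised by the total weight. masked_mse_sketch is a hypothetical name.

```python
import numpy as np

# NumPy sketch of a masked MSE using the convention documented for masked_mse
# (0 = include, 1 = ignore, fractional values = partial weight).
# ASSUMPTIONS: the per-pixel weight is (1 - mask), optional per-image sample
# weights multiply it, and the result is normalised by the total weight.


def masked_mse_sketch(y_true, y_pred, mask, sample_weight=None):
    keep = 1.0 - mask                                   # 1 where the pixel counts
    if sample_weight is not None:
        keep = keep * sample_weight[:, None, None]      # broadcast (batch,) over H, W
    total = np.sum(keep)
    if total == 0:
        return 0.0                                      # everything is masked out
    return float(np.sum(keep * (y_true - y_pred) ** 2) / total)


y_true = np.zeros((2, 2, 2))
y_pred = np.ones((2, 2, 2))
y_pred[1] = 3.0                                         # second image: error 9 per pixel
mask = np.zeros((2, 2, 2))
mask[0, 0, 0] = 1.0                                     # ignore one pixel of image 0
all_images = masked_mse_sketch(y_true, y_pred, mask)
first_only = masked_mse_sketch(y_true, y_pred, mask, sample_weight=np.array([1.0, 0.0]))
print(all_images, first_only)
```

Normalising by the total weight rather than the pixel count keeps the loss scale independent of how many pixels are masked.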
- wf_psf.training.train_utils.train_cycle_part(psf_model: tensorflow.keras.Model, inputs: tensorflow.Tensor, outputs: tensorflow.Tensor, batch_size: int, epochs: int, optimizer: tensorflow.keras.optimizers.Optimizer, loss: Callable, metrics: list[Callable], validation_data: tuple[tensorflow.Tensor, tensorflow.Tensor] | None = None, callbacks: list[Callable] | None = None, sample_weight: tensorflow.Tensor | None = None, verbose: int = 1, cycle_part: str = 'parametric') → tensorflow.keras.Model
Train a single component of the PSF model, either the parametric or the non-parametric part, using the specified configuration.
- Parameters:
psf_model (tf.keras.Model) – A TensorFlow model representing the PSF (Point Spread Function), which consists of either a parametric or a non-parametric component.
inputs (tf.Tensor) – Input data for training the model. Expected to be a tensor with the shape of the input batch.
outputs (tf.Tensor) – Target output data for training the model. Must contain the same number of samples (first dimension) as inputs.
batch_size (int) – The number of samples per batch during training.
epochs (int) – The number of epochs to train the model.
optimizer (tf.keras.optimizers.Optimizer) – The optimizer used for training the model (e.g., Adam, SGD).
loss (Callable) – The loss function used for training the model. Typically a callable like tf.keras.losses.MeanSquaredError().
metrics (list of Callable) – List of metrics to monitor during training. Each element should be a callable metric (e.g., tf.keras.metrics.MeanSquaredError()).
validation_data (tuple of (tf.Tensor, tf.Tensor), optional) – Tuple of input and output tensors to evaluate the model during training. Default is None.
callbacks (list of Callable, optional) – List of callbacks to apply during training, such as tf.keras.callbacks.EarlyStopping. Default is None.
sample_weight (tf.Tensor, optional) – Weights for the samples during training. Default is None.
verbose (int, optional) – Verbosity mode (0, 1, or 2). Default is 1.
cycle_part (str, optional) – Specifies which part of the model to train (“parametric” or “non-parametric”). Default is “parametric”.
- Returns:
The trained TensorFlow model after completing the specified number of epochs.
- Return type:
tf.keras.Model
Notes
This function trains the part of the model selected by cycle_part: “parametric” trains the parametric part, while “non-parametric” trains the non-parametric part. The model is built using the build_PSF_model function before fitting.
Examples
import tensorflow as tf
from wf_psf.training.train_utils import train_cycle_part

# model, train_inputs, train_outputs, val_inputs and val_outputs are assumed to exist already.
model = train_cycle_part(
    psf_model=model, inputs=train_inputs, outputs=train_outputs,
    batch_size=32, epochs=10,
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
    validation_data=(val_inputs, val_outputs),
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=3)],
    sample_weight=None, verbose=1,
)