"""
Training utilities for the PSF model.
This module contains helper functions and utilities related to the training
process for the PSF model. These functions help with managing training cycles,
callbacks, and related operations.
Authors: Tobias Liaudat <tobias.liaudat@cea.fr>, Ezequiel Centofanti <ezequiel.centofanti@cea.fr>,
Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import numpy as np
import tensorflow as tf
from typing import Optional, Callable, Union
from wf_psf.psf_models.psf_models import compile_PSF_model
from wf_psf.utils.utils import NoiseEstimator, generalised_sigmoid
import logging
logger = logging.getLogger(__name__)
class L1ParamScheduler(tf.keras.callbacks.Callback):
    """L1 rate scheduler that adjusts the L1 rate during training according to a specified schedule.

    This callback modifies the L1 regularization rate at the beginning of each
    epoch based on the given scheduling function. The function takes the epoch
    index and the current L1 rate as inputs, and it outputs the updated L1 rate.

    Parameters
    ----------
    l1_schedule_rule: function
        A function that defines how to update the L1 rate. The function should take two arguments:
        - `epoch` (int): The current epoch index, starting from 0.
        - `current_l1_rate` (float): The L1 rate at the current epoch.
        The function should return a new L1 rate (float) to be applied at the next epoch.
    """

    def __init__(self, l1_schedule_rule):
        """
        Initialize the L1ParamScheduler.

        Parameters
        ----------
        l1_schedule_rule : function
            A function that defines how to update the L1 rate at each epoch.
            See the class docstring for details.
        """
        super().__init__()
        self.l1_schedule_rule = l1_schedule_rule

    def on_epoch_begin(self, epoch, logs=None):
        """
        Execute callback function at the beginning of each epoch to adjust the L1 rate.

        This function gets the current L1 rate from the model, computes the new
        scheduled L1 rate using the `l1_schedule_rule` function, and sets it
        back on the model.

        Parameters
        ----------
        epoch: int
            The current epoch index, starting from 0.
        logs: dict, optional
            A dictionary containing logs for the current epoch (default is None).
        """
        # Read the current L1 rate from the model.
        l1_rate = float(tf.keras.backend.get_value(self.model.l1_rate))
        # Call the schedule function to get the scheduled L1 rate.
        scheduled_l1_rate = self.l1_schedule_rule(epoch, l1_rate)
        # Write the value back to the model before this epoch starts.
        # NOTE(review): assumes the model exposes a `set_l1_rate` method — it is
        # not defined in this module; confirm against the PSF model class.
        self.model.set_l1_rate(scheduled_l1_rate)
def masked_mse(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    mask: tf.Tensor,
    sample_weight: Optional[tf.Tensor] = None,
) -> tf.Tensor:
    """Compute the mean squared error restricted to the unmasked pixels.

    Parameters
    ----------
    y_true : tf.Tensor
        Ground-truth values of shape (batch, height, width).
    y_pred : tf.Tensor
        Predicted values of shape (batch, height, width).
    mask : tf.Tensor
        Per-pixel mask, which **can contain float values in [0,1]**.
        - `0` means to include the pixel.
        - `1` means to ignore the pixel.
        - Values in `(0,1)` act as weights for partial consideration.
    sample_weight : tf.Tensor, optional
        Per-image weights of shape (batch,). When given, they are broadcast
        over the spatial dimensions.

    Returns
    -------
    tf.Tensor
        Scalar mean squared error: per-image error normalized by the kept
        pixel weight, averaged over the batch.
    """
    # Inclusion weights: 1 keeps the pixel, 0 drops it.
    keep = 1 - mask
    squared_err = keep * tf.square(y_true - y_pred)  # (batch, height, width)
    # Optional per-image weighting, broadcast over height and width.
    if sample_weight is not None:
        squared_err = squared_err * tf.reshape(sample_weight, (-1, 1, 1))
    # Total inclusion weight per image, used to normalize that image's error.
    keep_total = tf.reduce_sum(keep, axis=[1, 2])  # (batch,)
    normalized = squared_err / tf.reshape(keep_total, (-1, 1, 1))
    # Average the per-image means over the batch.
    batch_size = tf.cast(tf.shape(y_true)[0], y_true.dtype)
    return tf.reduce_sum(normalized) / batch_size
class MaskedMeanSquaredError(tf.keras.losses.Loss):
    """
    Computes the masked mean squared error (MSE) loss between predictions and targets.

    This loss function assumes that `y_true` has two components in the last axis:
    - `y_true[..., 0]`: the target values.
    - `y_true[..., 1]`: a mask in [0, 1] where, following `masked_mse`:
        - `0` means the pixel is included in the loss.
        - `1` means the pixel is ignored.
        - Values in (0,1) act as weights for partial contribution.
    """

    def __init__(self, name: str = "masked_mean_squared_error", **kwargs):
        """
        Initialize the masked mean squared error loss.

        Parameters
        ----------
        name : str, optional
            Name of the loss function. Default is "masked_mean_squared_error".
        **kwargs : dict
            Additional keyword arguments passed to `tf.keras.losses.Loss`.
        """
        super().__init__(name=name, **kwargs)

    def __call__(
        self,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        sample_weight: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Invoke the loss computation with support for different shapes of inputs.

        Overrides `tf.keras.losses.Loss.__call__` to forward `sample_weight`
        straight to `call`, bypassing Keras' built-in reduction machinery
        (the reduction is performed inside `masked_mse` instead).

        Parameters
        ----------
        y_true : tf.Tensor
            A tensor of shape (batch, height, width, 2), where the last channel
            contains the true values and the mask.
        y_pred : tf.Tensor
            A tensor of shape (batch, height, width), containing the predicted values.
        sample_weight : tf.Tensor, optional
            Optional per-sample weighting tensor of shape (batch,).

        Returns
        -------
        tf.Tensor
            Scalar tensor representing the final masked MSE loss.
        """
        return self.call(y_true, y_pred, sample_weight)

    def call(
        self,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        sample_weight: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Compute the masked mean squared error loss.

        Parameters
        ----------
        y_true : tf.Tensor
            Tensor of shape (batch, height, width, 2) containing the ground truth
            values and the mask.
        y_pred : tf.Tensor
            Tensor of shape (batch, height, width) containing the predicted values.
        sample_weight : tf.Tensor, optional
            Optional sample weights of shape (batch,). Broadcast over spatial dims.

        Returns
        -------
        tf.Tensor
            Scalar tensor representing the mean squared error over masked pixels.
        """
        # Split the stacked tensor into targets and mask, then delegate.
        y_target = y_true[..., 0]
        mask = y_true[..., 1]
        return masked_mse(y_target, y_pred, mask, sample_weight)
class MaskedMeanSquaredErrorMetric(tf.keras.metrics.Metric):
    """
    A custom metric class for computing the masked mean squared error (MSE).

    This metric computes the mean squared error over the masked regions of the
    true values and predictions. A mask is applied such that a mask value of
    `1` excludes the pixel and `0` includes the pixel in the error computation.
    The metric is updated after every batch and returns the average masked MSE
    over all batches processed since the last reset.
    """

    def __init__(self, name: str = "masked_mean_squared_error", **kwargs) -> None:
        """
        Initialize the MaskedMeanSquaredErrorMetric.

        Parameters
        ----------
        name : str, optional
            The name of the metric instance. Defaults to "masked_mean_squared_error".
        **kwargs : additional keyword arguments
            Additional arguments passed to the parent class initializer.
        """
        super().__init__(name=name, **kwargs)
        # Running sum of the per-batch losses and number of batches seen.
        self.total_loss = self.add_weight(name="total_loss", initializer="zeros")
        self.batch_count = self.add_weight(name="batch_count", initializer="zeros")

    def update_state(
        self,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        sample_weight: Optional[tf.Tensor] = None,
    ) -> None:
        """
        Update the state of the metric with the true values, predictions, and optional sample weights.

        This method calculates the masked MSE loss for the given batch and
        accumulates the total loss and batch count.

        Parameters
        ----------
        y_true : tf.Tensor
            True values with shape `(batch_size, height, width, channels)`,
            where the last dimension contains both the target values and the mask.
        y_pred : tf.Tensor
            Predicted values with shape `(batch_size, height, width, channels)`.
        sample_weight : tf.Tensor, optional
            Sample weights for each instance in the batch, with shape `(batch_size,)`.
            If not provided, all instances are treated equally.

        Returns
        -------
        None
        """
        # Split the stacked tensor into targets and mask.
        y_target = y_true[..., 0]
        mask = y_true[..., 1]
        loss = masked_mse(y_target, y_pred, mask, sample_weight)
        self.total_loss.assign_add(loss)
        self.batch_count.assign_add(1.0)

    def result(self) -> tf.Tensor:
        """
        Compute and return the current masked MSE value, averaged over all batches.

        Returns
        -------
        tf.Tensor
            The current masked MSE value (average loss per batch).
        """
        # NOTE(review): division by zero (NaN) if called before any update.
        return self.total_loss / self.batch_count

    def reset_state(self) -> None:
        """
        Reset the state of the metric, clearing the accumulated total loss and batch count.

        This method is typically called at the start of a new evaluation or
        after a new epoch.

        Returns
        -------
        None
        """
        self.total_loss.assign(0.0)
        self.batch_count.assign(0.0)
def l1_schedule_rule(epoch_n: int, l1_rate: float) -> float:
    """
    Schedule the L1 rate based on the epoch number.

    Every 10th epoch (excluding the very first epoch) the L1 rate is halved;
    on any other epoch the rate is passed through unchanged.

    Parameters
    ----------
    epoch_n: int
        The current epoch number, where the epoch index starts from 0.
    l1_rate: float
        The current L1 regularization rate.

    Returns
    -------
    float
        The updated L1 rate for the given epoch.
    """
    # Pass-through on epoch 0 and on epochs that are not multiples of 10.
    if epoch_n == 0 or epoch_n % 10 != 0:
        return l1_rate
    scheduled_l1_rate = l1_rate / 2
    logger.info(f"Epoch {epoch_n:05d}: L1 rate is {scheduled_l1_rate:0.4e}.")
    return scheduled_l1_rate
def calculate_sample_weights(
    outputs: np.ndarray,
    use_sample_weights: bool,
    loss: Union[str, Callable, None],
    apply_sigmoid: bool = False,
    sigmoid_max_val: float = 5.0,
    sigmoid_power_k: float = 1.0,
) -> Optional[np.ndarray]:
    """
    Calculate sample weights based on image noise standard deviation.

    The function computes sample weights by estimating the noise standard
    deviation for each image, calculating the inverse variance, and then
    normalizing the weights by dividing by the median.

    Parameters
    ----------
    outputs: np.ndarray
        A 3D array of shape (batch_size, height, width) representing images.
        When the masked loss is used, an extra trailing axis of size 2 is
        expected, holding the image in channel 0 and the mask in channel 1.
    use_sample_weights: bool
        Flag indicating whether to compute sample weights. If True, sample
        weights will be computed based on the image noise.
    loss: str, callable, optional
        The loss function used for training. If the loss (or its `name`
        attribute) is "masked_mean_squared_error", the noise standard
        deviation is estimated on the masked images.
    apply_sigmoid: bool, optional
        Flag indicating whether to apply a generalized sigmoid function to
        the sample weights. Default is False.
    sigmoid_max_val: float, optional
        The maximum value for the sigmoid function. Default is 5.0.
    sigmoid_power_k: float, optional
        The power parameter for the sigmoid function. Default is 1.0.
        This parameter controls the steepness of the sigmoid curve.
        Higher values make the curve steeper.

    Returns
    -------
    np.ndarray or None
        An array of sample weights, or None if `use_sample_weights` is False.
    """
    if not use_sample_weights:
        return None

    img_dim = (outputs.shape[1], outputs.shape[2])
    # Noise-estimation window radius scales with the image size.
    win_rad = np.ceil(outputs.shape[1] / 3.33)
    std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad)

    # The masked loss can be identified either by its name string or by the
    # `name` attribute of the loss object.
    is_masked_loss = loss is not None and (
        (isinstance(loss, str) and loss == "masked_mean_squared_error")
        or (hasattr(loss, "name") and loss.name == "masked_mean_squared_error")
    )
    if is_masked_loss:
        logger.info("Estimating noise standard deviation for masked images..")
        images = outputs[..., 0]
        # Invert the mask: True marks pixels usable for noise estimation.
        masks = np.array(1 - outputs[..., 1], dtype=bool)
        imgs_std = np.array(
            [std_est.estimate_noise(_im, _win) for _im, _win in zip(images, masks)]
        )
    else:
        logger.info("Estimating noise standard deviation for images..")
        # Estimate noise standard deviation on the full images.
        imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs])

    # Inverse-variance weights, scaled so their median is 1.
    variances = imgs_std**2
    sample_weight = 1 / variances
    sample_weight /= np.median(sample_weight)

    # Optionally squash the weights with a generalized sigmoid.
    if apply_sigmoid:
        sample_weight = generalised_sigmoid(
            sample_weight, max_val=sigmoid_max_val, power_k=sigmoid_power_k
        )
    return sample_weight
def train_cycle_part(
    psf_model: tf.keras.Model,
    inputs: tf.Tensor,
    outputs: tf.Tensor,
    batch_size: int,
    epochs: int,
    optimizer: tf.keras.optimizers.Optimizer,
    loss: Callable,
    metrics: list[Callable],
    validation_data: Optional[tuple[tf.Tensor, tf.Tensor]] = None,
    callbacks: Optional[list[Callable]] = None,
    sample_weight: Optional[tf.Tensor] = None,
    verbose: int = 1,
    cycle_part: str = "parametric",
) -> tf.keras.callbacks.History:
    """
    Train either the parametric or non-parametric part of the PSF model using
    the specified parameters. This function trains a single component of the
    model (either parametric or non-parametric) based on the provided
    configuration.

    Parameters
    ----------
    psf_model: tf.keras.Model
        A TensorFlow model representing the PSF (Point Spread Function), which consists of either a parametric or a non-parametric component.
    inputs: tf.Tensor
        Input data for training the model. Expected to be a tensor with the shape of the input batch.
    outputs: tf.Tensor
        Target output data for training the model. Expected to match the shape of `inputs`.
    batch_size: int
        The number of samples per batch during training.
    epochs: int
        The number of epochs to train the model.
    optimizer: tf.keras.optimizers.Optimizer
        The optimizer used for training the model (e.g., Adam, SGD).
    loss: Callable
        The loss function used for training the model. Typically a callable like `tf.keras.losses.MeanSquaredError()`.
    metrics: list of Callable
        List of metrics to monitor during training. Each element should be a callable metric (e.g., accuracy, precision).
    validation_data: tuple of (tf.Tensor, tf.Tensor), optional
        Tuple of input and output tensors to evaluate the model during training. Default is None.
    callbacks: list of Callable, optional
        List of callbacks to apply during training, such as `tf.keras.callbacks.EarlyStopping`. Default is None.
    sample_weight: tf.Tensor, optional
        Weights for the samples during training. Default is None.
    verbose: int, optional
        Verbosity mode (0, 1, or 2). Default is 1.
    cycle_part: str, optional
        Specifies which part of the model to train ("parametric" or "non-parametric"). Default is "parametric".

    Returns
    -------
    tf.keras.callbacks.History
        The `History` object returned by `Model.fit()`, containing the
        per-epoch training (and validation) metrics. Note the trained model
        itself is modified in place.

    Notes
    -----
    This function trains the model based on the provided `cycle_part`. If `cycle_part` is set to
    "parametric", the function assumes the model is being trained in a parametric setting, while
    "non-parametric" indicates the training of a non-parametric part. The model is compiled using the
    `compile_PSF_model` function before fitting.

    Examples
    --------
    history = train_cycle_part(
        psf_model=model,
        inputs=train_inputs,
        outputs=train_outputs,
        batch_size=32,
        epochs=10,
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[tf.keras.metrics.MeanAbsoluteError()],
        validation_data=(val_inputs, val_outputs),
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3)],
        sample_weight=None,
        verbose=1
    )
    """
    logger.info(f"Starting {cycle_part} update..")
    # Compile before fitting so the given optimizer/loss/metrics take effect.
    psf_model = compile_PSF_model(
        psf_model, optimizer=optimizer, loss=loss, metrics=metrics
    )
    # `Model.fit` returns a History object, not the model.
    return psf_model.fit(
        x=inputs,
        y=outputs,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=validation_data,
        callbacks=callbacks,
        sample_weight=sample_weight,
        verbose=verbose,
    )
def get_callbacks(callback1, callback2):
    """
    Merge two optional callback lists into a single list.

    Returns None when both inputs are None. Otherwise a None (or empty)
    input is treated as an empty list and the two are concatenated.

    Parameters
    ----------
    callback1: list of tf.keras.callbacks.Callback or None
        The first callback list (e.g., parametric or non-parametric).
    callback2: list of tf.keras.callbacks.Callback or None
        The second callback list (e.g., general callback).

    Returns
    -------
    list of tf.keras.callbacks.Callback or None
        The combined list of callbacks, or None when both inputs are None.
    """
    if callback1 is None and callback2 is None:
        return None
    first = callback1 if callback1 else []
    second = callback2 if callback2 else []
    return first + second
def general_train_cycle(
    psf_model,
    inputs,
    outputs,
    validation_data,
    batch_size,
    learning_rate_param,
    learning_rate_non_param,
    n_epochs_param,
    n_epochs_non_param,
    param_optim=None,
    non_param_optim=None,
    param_loss=None,
    non_param_loss=None,
    param_metrics=None,
    non_param_metrics=None,
    param_callback=None,
    non_param_callback=None,
    general_callback=None,
    first_run=False,
    cycle_def="complete",
    use_sample_weights=False,
    apply_sigmoid=True,
    sigmoid_max_val=5.0,
    sigmoid_power_k=1.0,
    verbose=1,
):
    """
    Perform a Bi-Cycle Descent (BCD) training iteration on a semi-parametric model.

    The function alternates between optimizing the parametric and/or non-parametric parts of the model
    across specified training cycles. Each part of the model can be trained individually or together
    depending on the `cycle_def` parameter.

    For the parametric part:
    - Default learning rate: `learning_rate_param = 1e-2`
    - Default epochs: `n_epochs_param = 20`
    For the non-parametric part:
    - Default learning rate: `learning_rate_non_param = 1.0`
    - Default epochs: `n_epochs_non_param = 100`

    Parameters
    ----------
    psf_model: tf.keras.Model
        A TensorFlow model representing the PSF (Point Spread Function), which may consist of both parametric and non-parametric components, or an individual component. These components are partitioned for training, with each part addressing different aspects of the PSF.
    inputs: Tensor or list of tensors
        Input data for training (`Model.fit()`).
    outputs: Tensor
        Output data for training (`Model.fit()`).
    validation_data: Tuple
        Validation data used for model evaluation during training.
        (input_data, output_data).
    batch_size: int
        The batch size for the training.
    learning_rate_param: float
        Learning rate for the parametric part of the PSF model.
    learning_rate_non_param: float
        Learning rate for the non-parametric part of the PSF model.
    n_epochs_param: int
        Number of epochs to train the parametric part.
    n_epochs_non_param: int
        Number of epochs to train the non-parametric part.
    param_optim: tf.keras.optimizers.Optimizer, optional
        Optimizer for the parametric part. Defaults to Adam if not provided.
    non_param_optim: tf.keras.optimizers.Optimizer, optional
        Optimizer for the non-parametric part. Defaults to Adam if not provided.
    param_loss: tf.keras.losses.Loss, optional
        Loss function for the parametric part. Defaults to the MeanSquaredError().
    non_param_loss: tf.keras.losses.Loss, optional
        Loss function for the non-parametric part. Defaults to MeanSquaredError().
    param_metrics: list of tf.keras.metrics.Metric, optional
        List of metrics for the parametric part. Defaults to MeanSquaredError().
    non_param_metrics: list of tf.keras.metrics.Metric, optional
        List of metrics for the non-parametric part. Defaults to MeanSquaredError().
    param_callback: list of tf.keras.callbacks.Callback, optional
        Callback for the parametric part only. Defaults to no callback.
    non_param_callback: list of tf.keras.callbacks.Callback, optional
        Callback for the non-parametric part only. Defaults to no callback.
    general_callback: list of tf.keras.callbacks.Callback, optional
        Callback shared between both the parametric and non-parametric parts. Defaults to no callback.
    first_run: bool, optional
        If True, the first iteration of training is assumed, and the non-parametric part
        is not considered during the parametric training. Default is False.
    cycle_def: str, optional
        Defines the training cycle: `parametric`, `non-parametric`, `complete`, `only-parametric`, or `only-non-parametric`.
        The `complete` cycle trains both parts; the other options train only the named part,
        with the `only-*` variants additionally zeroing out the other component. Default is `complete`.
    use_sample_weights: bool, optional
        If True, sample weights are used in training. Sample weights are computed
        based on estimated noise variance. Default is False.
    apply_sigmoid: bool, optional
        If True, a generalized sigmoid function is applied to the sample weights.
        Default is True.
    sigmoid_max_val: float, optional
        The maximum value for the sigmoid function. Default is `5.0`.
    sigmoid_power_k: float, optional
        The power parameter for the sigmoid function. Default is `1.0`.
        This parameter controls the steepness of the sigmoid curve.
    verbose: int, optional
        Verbosity mode. `0` = silent, `1` = progress bar, `2` = one line per epoch.
        Default is 1.

    Returns
    -------
    psf_model: tf.keras.Model
        The trained PSF model.
    hist_param: tf.keras.callbacks.History
        History object for the parametric training.
    hist_non_param: tf.keras.callbacks.History
        History object for the non-parametric training.
    """
    # Histories default to None so a partial cycle (e.g. "only-parametric")
    # still returns a well-formed (model, hist_param, hist_non_param) triple.
    hist_param, hist_non_param = None, None
    # Parametric part: resolve optimizer, loss and metrics (defaults are
    # applied by `configure_optimizer_and_loss` when an argument is None).
    optimizer, loss, metrics = configure_optimizer_and_loss(
        learning_rate_param, param_optim, param_loss, param_metrics
    )
    # Calculate per-image sample weights from the estimated noise variance.
    # NOTE(review): the weights are computed once using the *parametric* loss
    # and reused unchanged for the non-parametric cycle below — confirm this
    # is intended when `param_loss` and `non_param_loss` differ.
    sample_weight = calculate_sample_weights(
        outputs,
        use_sample_weights,
        loss,
        apply_sigmoid,
        sigmoid_max_val,
        sigmoid_power_k,
    )
    # Define the training cycle
    if cycle_def in ("parametric", "complete", "only-parametric"):
        # If it is the first run
        if first_run:
            # Set the non-parametric model to zero
            # With alpha to zero its already enough
            psf_model.set_zero_nonparam()
        if cycle_def == "only-parametric":
            # Set the non-parametric part to zero
            psf_model.set_zero_nonparam()
        # Combine part-specific and shared callbacks for the parametric part.
        # If both are None, set callbacks to None
        callbacks = get_callbacks(param_callback, general_callback)
        # Train only the parametric layers; freeze the non-parametric ones.
        psf_model.set_trainable_layers(param_bool=True, nonparam_bool=False)
        hist_param = train_cycle_part(
            psf_model=psf_model,
            inputs=inputs,
            outputs=outputs,
            batch_size=batch_size,
            epochs=n_epochs_param,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            validation_data=validation_data,
            callbacks=callbacks,
            sample_weight=sample_weight,
            verbose=verbose,
            cycle_part="parametric",
        )
    # Non-parametric part: resolve its own optimizer, loss and metrics.
    optimizer, loss, metrics = configure_optimizer_and_loss(
        learning_rate_non_param,
        non_param_optim,
        non_param_loss,
        non_param_metrics,
    )
    if cycle_def in ("non-parametric", "complete", "only-non-parametric"):
        if first_run:
            # Set the non-parametric model to non-zero
            # With alpha to zero its already enough
            psf_model.set_nonzero_nonparam()
        if cycle_def == "only-non-parametric":
            # Zero out the parametric coefficients so only the
            # non-parametric component contributes.
            coeff_mat = psf_model.get_coeff_matrix()
            psf_model.assign_coeff_matrix(tf.zeros_like(coeff_mat))
        # Combine part-specific and shared callbacks for the non-parametric part.
        # If both are None, set callbacks to None
        callbacks = get_callbacks(non_param_callback, general_callback)
        # Train only the non-parametric layers; freeze the parametric ones.
        psf_model.set_trainable_layers(param_bool=False, nonparam_bool=True)
        hist_non_param = train_cycle_part(
            psf_model=psf_model,
            inputs=inputs,
            outputs=outputs,
            batch_size=batch_size,
            epochs=n_epochs_non_param,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            validation_data=validation_data,
            callbacks=callbacks,
            sample_weight=sample_weight,
            verbose=verbose,
            cycle_part="non-parametric",
        )
    return psf_model, hist_param, hist_non_param