Source code for wf_psf.training.train_utils

import numpy as np
import tensorflow as tf
from wf_psf.psf_models.tf_psf_field import build_PSF_model
from wf_psf.utils.utils import NoiseEstimator
import logging

logger = logging.getLogger(__name__)


class L1ParamScheduler(tf.keras.callbacks.Callback):
    """L1 rate scheduler which sets the L1 rate according to a schedule.

    Parameters
    ----------
    l1_schedule_rule: function
        A function that takes an epoch index (integer, indexed from 0) and the
        current l1_rate as inputs, and returns a new l1_rate as output (float).

    """

    def __init__(self, l1_schedule_rule):
        super(L1ParamScheduler, self).__init__()
        self.l1_schedule_rule = l1_schedule_rule

    def on_epoch_begin(self, epoch, logs=None):
        # Get the current L1 rate from the model.
        l1_rate = float(tf.keras.backend.get_value(self.model.l1_rate))
        # Call the schedule function to get the scheduled L1 rate.
        scheduled_l1_rate = self.l1_schedule_rule(epoch, l1_rate)
        # Set the new value on the model before this epoch starts.
        self.model.set_l1_rate(scheduled_l1_rate)


def l1_schedule_rule(epoch_n, l1_rate):
    """Halve the L1 rate every 10 epochs."""
    if epoch_n != 0 and epoch_n % 10 == 0:
        scheduled_l1_rate = l1_rate / 2
        logger.info(
            "\nEpoch %05d: L1 rate is %0.4e." % (epoch_n, scheduled_l1_rate)
        )
        return scheduled_l1_rate
    else:
        return l1_rate
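

# A minimal usage sketch (not part of the original module), assuming the model
# exposes the `l1_rate` attribute and `set_l1_rate` method that
# L1ParamScheduler requires. All data variables are placeholders; the callback
# is handed to a train cycle through the `param_callback` argument.
def _example_l1_scheduler_usage(tf_semiparam_field, inputs, outputs, val_data):
    # Halve the L1 rate every 10 epochs during the parametric updates.
    l1_callback = L1ParamScheduler(l1_schedule_rule)
    return general_train_cycle(
        tf_semiparam_field,
        inputs=inputs,
        outputs=outputs,
        validation_data=val_data,
        batch_size=32,
        learning_rate_param=1e-2,
        learning_rate_non_param=1.0,
        n_epochs_param=20,
        n_epochs_non_param=100,
        param_callback=[l1_callback],
    )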


def general_train_cycle(
    tf_semiparam_field,
    inputs,
    outputs,
    validation_data,
    batch_size,
    learning_rate_param,
    learning_rate_non_param,
    n_epochs_param,
    n_epochs_non_param,
    param_optim=None,
    non_param_optim=None,
    param_loss=None,
    non_param_loss=None,
    param_metrics=None,
    non_param_metrics=None,
    param_callback=None,
    non_param_callback=None,
    general_callback=None,
    first_run=False,
    cycle_def="complete",
    use_sample_weights=False,
    verbose=1,
):
    """Perform one block-coordinate descent (BCD) iteration on the model.

    Defines the model optimisation. For the parametric part we use
    ``learning_rate_param = 1e-2`` and ``n_epochs_param = 20``. For the
    non-parametric part we use ``learning_rate_non_param = 1.0`` and
    ``n_epochs_non_param = 100``.

    Parameters
    ----------
    tf_semiparam_field: tf.keras.Model
        The model to be trained.
    inputs: Tensor or list of tensors
        Inputs used for Model.fit().
    outputs: Tensor
        Outputs used for Model.fit().
    validation_data: tuple
        Validation data used for Model.fit(). Tuple of input and output
        validation data.
    batch_size: int
        Batch size for the training.
    learning_rate_param: float
        Learning rate for the parametric part.
    learning_rate_non_param: float
        Learning rate for the non-parametric part.
    n_epochs_param: int
        Number of epochs for the parametric part.
    n_epochs_non_param: int
        Number of epochs for the non-parametric part.
    param_optim: Tensorflow optimizer
        Optimizer for the parametric part. Optional, default is the Adam
        optimizer.
    non_param_optim: Tensorflow optimizer
        Optimizer for the non-parametric part. Optional, default is the Adam
        optimizer.
    param_loss: Tensorflow loss
        Loss function for the parametric part. Optional, default is the
        MeanSquaredError() loss.
    non_param_loss: Tensorflow loss
        Loss function for the non-parametric part. Optional, default is the
        MeanSquaredError() loss.
    param_metrics: Tensorflow metrics
        Metrics for the parametric part. Optional, default is the
        MeanSquaredError() metric.
    non_param_metrics: Tensorflow metrics
        Metrics for the non-parametric part. Optional, default is the
        MeanSquaredError() metric.
    param_callback: Tensorflow callback
        Callback for the parametric part only. Optional, default is no
        callback.
    non_param_callback: Tensorflow callback
        Callback for the non-parametric part only. Optional, default is no
        callback.
    general_callback: Tensorflow callback
        Callback shared by both the parametric and non-parametric parts.
        Optional, default is no callback.
    first_run: bool
        If True, this is the first iteration of the model training. The
        non-parametric part is not considered during the first parametric
        training.
    cycle_def: str
        Train cycle definition. It can be: ``parametric``, ``non-parametric``,
        ``complete``, ``only-parametric`` or ``only-non-parametric``.
        Default is ``complete``.
    use_sample_weights: bool
        If True, sample weights are used for the training. The sample weights
        are computed as the inverse of the estimated noise variance.
    verbose: int
        Verbosity mode used for the training procedure. If a log of the
        training is being saved, ``verbose=2`` is recommended.

    Returns
    -------
    tf_semiparam_field: tf.keras.Model
        Trained Tensorflow model.
    hist_param: Tensorflow History object
        History of the parametric training.
    hist_non_param: Tensorflow History object
        History of the non-parametric training.

    """
    # Initialize return variables
    hist_param = None
    hist_non_param = None

    ## Parametric train

    # Define loss
    if param_loss is None:
        loss = tf.keras.losses.MeanSquaredError()
    else:
        loss = param_loss

    # Define optimiser
    if param_optim is None:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=learning_rate_param,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-07,
            amsgrad=False,
        )
    else:
        optimizer = param_optim

    # Define metrics
    if param_metrics is None:
        metrics = [tf.keras.metrics.MeanSquaredError()]
    else:
        metrics = param_metrics

    # Define callbacks
    if param_callback is None and general_callback is None:
        callbacks = None
    else:
        if general_callback is None:
            callbacks = param_callback
        elif param_callback is None:
            callbacks = general_callback
        else:
            callbacks = general_callback + param_callback

    # Calculate sample weights
    if use_sample_weights:
        # Build the standard deviation estimator
        img_dim = (outputs.shape[1], outputs.shape[2])
        win_rad = np.ceil(outputs.shape[1] / 3.33)
        std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad)
        # Estimate the noise std_dev of each observation
        imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs])
        # Calculate the weights from the noise variances
        variances = imgs_std**2
        # Define the sample weight strategy
        strategy_opt = 1
        if strategy_opt == 0:
            # Parameters
            max_w = 2.0
            min_w = 0.1
            # Epsilon is to avoid outliers
            epsilon = np.median(variances) * 0.1
            w = 1 / (variances + epsilon)
            # Transform to [0, 1]
            scaled_w = (w - np.min(w)) / (np.max(w) - np.min(w))
            # Transform to [min_w, max_w]
            scaled_w = scaled_w * (max_w - min_w) + min_w
            # Adjust the mean to 1
            scaled_w = scaled_w + (1 - np.mean(scaled_w))
            scaled_w[scaled_w < min_w] = min_w
            # Save the weights
            sample_weight = scaled_w
        elif strategy_opt == 1:
            # Use the inverse variance as weights,
            # then scale the values by the median
            sample_weight = 1 / variances
            sample_weight /= np.median(sample_weight)
    else:
        sample_weight = None

    # Define the training cycle
    if (
        cycle_def == "parametric"
        or cycle_def == "complete"
        or cycle_def == "only-parametric"
    ):
        # If it is the first run
        if first_run:
            # Set the non-parametric model to zero.
            # Setting alpha to zero is already enough.
            tf_semiparam_field.set_zero_nonparam()

        if cycle_def == "only-parametric":
            # Set the non-parametric part to zero
            tf_semiparam_field.set_zero_nonparam()

        # Set the parametric layers as the only trainable ones
        tf_semiparam_field.set_trainable_layers(param_bool=True, nonparam_bool=False)

        # Compile the model for the first optimisation
        tf_semiparam_field = build_PSF_model(
            tf_semiparam_field,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
        )

        # Train the parametric part
        logger.info("Starting parametric update..")
        hist_param = tf_semiparam_field.fit(
            x=inputs,
            y=outputs,
            batch_size=batch_size,
            epochs=n_epochs_param,
            validation_data=validation_data,
            callbacks=callbacks,
            sample_weight=sample_weight,
            verbose=verbose,
        )

    ## Non-parametric train

    # Define the training cycle
    if (
        cycle_def == "non-parametric"
        or cycle_def == "complete"
        or cycle_def == "only-non-parametric"
    ):
        # If it is the first run
        if first_run:
            # Set the non-parametric model back to non-zero
            tf_semiparam_field.set_nonzero_nonparam()

        if cycle_def == "only-non-parametric":
            # Set the parametric part to zero
            coeff_mat = tf_semiparam_field.get_coeff_matrix()
            tf_semiparam_field.assign_coeff_matrix(tf.zeros_like(coeff_mat))

        # Set the non-parametric layer as the only trainable one
        tf_semiparam_field.set_trainable_layers(param_bool=False, nonparam_bool=True)

        # Define loss
        if non_param_loss is None:
            loss = tf.keras.losses.MeanSquaredError()
        else:
            loss = non_param_loss

        # Define optimiser
        if non_param_optim is None:
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=learning_rate_non_param,
                beta_1=0.9,
                beta_2=0.999,
                epsilon=1e-07,
                amsgrad=False,
            )
        else:
            optimizer = non_param_optim

        # Define metrics
        if non_param_metrics is None:
            metrics = [tf.keras.metrics.MeanSquaredError()]
        else:
            metrics = non_param_metrics

        # Define callbacks
        if non_param_callback is None and general_callback is None:
            callbacks = None
        else:
            if general_callback is None:
                callbacks = non_param_callback
            elif non_param_callback is None:
                callbacks = general_callback
            else:
                callbacks = general_callback + non_param_callback

        # Compile the model again for the second optimisation
        tf_semiparam_field = build_PSF_model(
            tf_semiparam_field,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
        )

        # Train the non-parametric part
        logger.info("Starting non-parametric update..")
        hist_non_param = tf_semiparam_field.fit(
            x=inputs,
            y=outputs,
            batch_size=batch_size,
            epochs=n_epochs_non_param,
            validation_data=validation_data,
            callbacks=callbacks,
            sample_weight=sample_weight,
            verbose=verbose,
        )

    return tf_semiparam_field, hist_param, hist_non_param
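

# A minimal usage sketch (not part of the original module): two BCD iterations
# with general_train_cycle. All variables are placeholders; `first_run=True`
# is set only on the first cycle so that the non-parametric part is zeroed
# during the first parametric update.
def _example_bcd_training(tf_semiparam_field, inputs, outputs, val_data):
    for cycle in range(2):
        tf_semiparam_field, hist_param, hist_non_param = general_train_cycle(
            tf_semiparam_field,
            inputs=inputs,
            outputs=outputs,
            validation_data=val_data,
            batch_size=32,
            learning_rate_param=1e-2,
            learning_rate_non_param=1.0,
            n_epochs_param=20,
            n_epochs_non_param=100,
            first_run=(cycle == 0),
            use_sample_weights=True,
            verbose=2,
        )
    return tf_semiparam_field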


def param_train_cycle(
    tf_semiparam_field,
    inputs,
    outputs,
    validation_data,
    batch_size,
    learning_rate,
    n_epochs,
    param_optim=None,
    param_loss=None,
    param_metrics=None,
    param_callback=None,
    general_callback=None,
    use_sample_weights=False,
    verbose=1,
):
    """Training cycle for the parametric model only."""
    # Define loss
    if param_loss is None:
        loss = tf.keras.losses.MeanSquaredError()
    else:
        loss = param_loss

    # Define optimiser
    if param_optim is None:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=learning_rate,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-07,
            amsgrad=False,
        )
    else:
        optimizer = param_optim

    # Define metrics
    if param_metrics is None:
        metrics = [tf.keras.metrics.MeanSquaredError()]
    else:
        metrics = param_metrics

    # Define callbacks
    if param_callback is None and general_callback is None:
        callbacks = None
    else:
        if general_callback is None:
            callbacks = param_callback
        elif param_callback is None:
            callbacks = general_callback
        else:
            callbacks = general_callback + param_callback

    # Calculate sample weights
    if use_sample_weights:
        # Build the standard deviation estimator
        img_dim = (outputs.shape[1], outputs.shape[2])
        win_rad = np.ceil(outputs.shape[1] / 3.33)
        std_est = NoiseEstimator(img_dim=img_dim, win_rad=win_rad)
        # Estimate the noise std_dev of each observation
        imgs_std = np.array([std_est.estimate_noise(_im) for _im in outputs])
        # Calculate the weights from the noise variances
        variances = imgs_std**2
        # Define the sample weight strategy
        strategy_opt = 1
        if strategy_opt == 0:
            # Parameters
            max_w = 2.0
            min_w = 0.1
            # Epsilon is to avoid outliers
            epsilon = np.median(variances) * 0.1
            w = 1 / (variances + epsilon)
            # Transform to [0, 1]
            scaled_w = (w - np.min(w)) / (np.max(w) - np.min(w))
            # Transform to [min_w, max_w]
            scaled_w = scaled_w * (max_w - min_w) + min_w
            # Adjust the mean to 1
            scaled_w = scaled_w + (1 - np.mean(scaled_w))
            scaled_w[scaled_w < min_w] = min_w
            # Save the weights
            sample_weight = scaled_w
        elif strategy_opt == 1:
            # Use the inverse variance as weights,
            # then scale the values by the median
            sample_weight = 1 / variances
            sample_weight /= np.median(sample_weight)
    else:
        sample_weight = None

    # Compile the model for the optimisation
    tf_semiparam_field = build_PSF_model(
        tf_semiparam_field, optimizer=optimizer, loss=loss, metrics=metrics
    )

    # Train the parametric part
    logger.info("Starting parametric update..")
    hist_param = tf_semiparam_field.fit(
        x=inputs,
        y=outputs,
        batch_size=batch_size,
        epochs=n_epochs,
        validation_data=validation_data,
        callbacks=callbacks,
        sample_weight=sample_weight,
        verbose=verbose,
    )

    return tf_semiparam_field, hist_param
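

# A minimal usage sketch (not part of the original module): a purely parametric
# update with param_train_cycle, e.g. for a parametric-only PSF model. All
# variable names are placeholders.
def _example_param_only_training(tf_semiparam_field, inputs, outputs, val_data):
    return param_train_cycle(
        tf_semiparam_field,
        inputs=inputs,
        outputs=outputs,
        validation_data=val_data,
        batch_size=32,
        learning_rate=1e-2,
        n_epochs=20,
        use_sample_weights=False,
        verbose=2,
    )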