"""Configs_Handler.
A module which provides general utility methods
to manage the parameters of the config files
:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import numpy as np
import logging
import os
import re
import glob
from wf_psf.data.data_handler import DataHandler
from wf_psf.metrics.metrics_interface import evaluate_model
from wf_psf.plotting.plots_interface import plot_metrics
from wf_psf.psf_models import psf_models
from wf_psf.psf_models.psf_model_loader import load_trained_psf_model
from wf_psf.training import train
from wf_psf.utils.read_config import read_conf
logger = logging.getLogger(__name__)

# Registry mapping a config-name pattern (e.g. "training_conf") to the handler
# class responsible for that configuration type. Populated at import time by
# the @register_configclass decorator below.
CONFIG_CLASS = {}


def register_configclass(config_class):
    """Register Config Class.

    A class decorator to register all config classes in the module-level
    ``CONFIG_CLASS`` dictionary, under each id listed in the class's
    ``ids`` attribute.

    Parameters
    ----------
    config_class: type
        Config Class; must define an ``ids`` iterable of string identifiers.

    Returns
    -------
    config_class: type
        The same class, unchanged (standard decorator convention).
    """
    # `config_id` instead of `id` to avoid shadowing the builtin.
    for config_id in config_class.ids:
        CONFIG_CLASS[config_id] = config_class
    return config_class


def set_run_config(config_name):
    """Set Run Configuration Class.

    A function to retrieve the appropriate configuration class based on
    the provided config name. Registered ids are treated as regex patterns
    matched against ``config_name`` with ``re.search``; the first match wins.

    Parameters
    ----------
    config_name: str
        Name of config

    Returns
    -------
    config_class: class
        The matching registered configuration class
    """
    try:
        # First registered id whose pattern occurs in config_name wins.
        config_id = [cid for cid in CONFIG_CLASS if re.search(cid, config_name)][0]
        config_class = CONFIG_CLASS[config_id]
    except IndexError:
        # No registered id matched config_name. (The previous version caught
        # KeyError, which the [0] indexing never raises, so an unknown config
        # name escaped as an unhandled IndexError.)
        logger.exception("Invalid config name. Check your config settings.")
        exit()
    return config_class
def get_run_config(run_config_name, *config_params):
    """Get Run Configuration Instance.

    A function to retrieve an instance of the appropriate configuration
    class for a WF-PSF run.

    Parameters
    ----------
    run_config_name: str
        Name of the run configuraton
    *config_params: str
        Run configuration parameters used for class instantiation.

    Returns
    -------
    config_class: object
        A class instance of the selected configuration class.
    """
    # Resolve the registered class, then instantiate it in one expression.
    return set_run_config(run_config_name)(*config_params)
class ConfigParameterError(Exception):
    """Custom Config Parameter Error exception class for specific error scenarios."""

    def __init__(self, message="An error with your config settings occurred."):
        # Keep the message on the instance and hand it to Exception so that
        # str(err) renders it.
        self.message = message
        super().__init__(message)
class DataConfigHandler:
    """DataConfigHandler.

    A class to handle data configuration parameters.

    Parameters
    ----------
    data_conf : str
        Path of the data configuration file
    training_model_params : Recursive Namespace object
        Recursive Namespace object containing the training model parameters
    batch_size : int
        Training hyperparameter used for batched pre-processing of data.
    load_data : bool
        Whether the DataHandler instances load their datasets on construction.
    """

    def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True):
        try:
            self.data_conf = read_conf(data_conf)
        except (FileNotFoundError, TypeError) as err:
            # An unreadable or missing data config is fatal for the run.
            logger.exception(err)
            exit()

        self.simPSF = psf_models.simPSF(training_model_params)

        # Both handlers share the same simulated PSF and wavelength binning.
        n_bins = training_model_params.n_bins_lda
        self.training_data = DataHandler(
            dataset_type="training",
            data_params=self.data_conf.data.training,
            simPSF=self.simPSF,
            n_bins_lambda=n_bins,
            load_data=load_data,
        )
        self.test_data = DataHandler(
            dataset_type="test",
            data_params=self.data_conf.data.test,
            simPSF=self.simPSF,
            n_bins_lambda=n_bins,
            load_data=load_data,
        )
        self.batch_size = batch_size
@register_configclass
class TrainingConfigHandler:
    """TrainingConfigHandler.

    A class to handle training configuration parameters.

    Parameters
    ----------
    ids: tuple
        A tuple containing a string id for the Configuration Class
    training_conf: str
        Path of the training configuration file
    file_handler: object
        A instance of the FileIOHandler class
    """

    ids = ("training_conf",)

    def __init__(self, training_conf, file_handler):
        self.training_conf = read_conf(training_conf)
        self.file_handler = file_handler

        # Shorthand for the training section of the loaded configuration.
        training_params = self.training_conf.training

        self.data_conf = DataConfigHandler(
            os.path.join(self.file_handler.config_path, training_params.data_config),
            training_params.model_params,
            training_params.training_hparams.batch_size,
            training_params.load_data_on_init,
        )
        self.data_conf.run_type = "training"
        self.file_handler.copy_conffile_to_output_dir(training_params.data_config)

        # All run artifacts live under the handler's run output directory.
        run_output_dir = self.file_handler._run_output_dir
        self.checkpoint_dir = self.file_handler.get_checkpoint_dir(run_output_dir)
        self.optimizer_dir = self.file_handler.get_optimizer_dir(run_output_dir)
        self.psf_model_dir = self.file_handler.get_psf_model_dir(run_output_dir)

    def run(self):
        """Run.

        A function to run wavediff according to the input configuration:
        train the model, then optionally evaluate metrics on it.
        """
        train.train(
            self.training_conf.training,
            self.data_conf,
            self.checkpoint_dir,
            self.optimizer_dir,
            self.psf_model_dir,
        )

        metrics_config = self.training_conf.training.metrics_config
        if metrics_config is None:
            # No metrics evaluation requested for this run.
            return

        self.file_handler.copy_conffile_to_output_dir(metrics_config)
        MetricsConfigHandler(
            os.path.join(self.file_handler.config_path, metrics_config),
            self.file_handler,
            self.training_conf,
        ).run()
@register_configclass
class MetricsConfigHandler:
    """MetricsConfigHandler.

    A class to handle metrics configuration parameters.

    Parameters
    ----------
    ids: tuple
        A tuple containing a string id for the Configuration Class
    metrics_conf: str
        Path to the metrics configuration file
    file_handler: object
        An instance of the FileIOHandler
    training_conf: RecursiveNamespace object
        RecursiveNamespace object containing the training configuration
        parameters. If None, it is loaded from the path inferred from the
        metrics configuration (see the ``training_conf`` setter).
    """

    ids = ("metrics_conf",)

    def __init__(self, metrics_conf, file_handler, training_conf=None):
        self._metrics_conf = read_conf(metrics_conf)
        self._file_handler = file_handler
        # NOTE: this goes through the property setter below; when
        # training_conf is None the setter loads the training config from
        # the trained-model path given in the metrics configuration.
        self.training_conf = training_conf
        # _load_data_conf reads self.training_conf, so the setter must have
        # run first.
        self.data_conf = self._load_data_conf()
        self.data_conf.run_type = "metrics"
        self.metrics_dir = self._file_handler.get_metrics_dir(
            self._file_handler._run_output_dir
        )
        self.trained_psf_model = self._load_trained_psf_model()

    @property
    def metrics_conf(self):
        """Get Metrics Conf.

        A function to return the metrics configuration file name.

        Returns
        -------
        RecursiveNamespace
            An instance of the metrics configuration file.
        """
        return self._metrics_conf

    @property
    def training_conf(self):
        """Returns the loaded training configuration."""
        return self._training_conf

    @training_conf.setter
    def training_conf(self, training_conf):
        """
        Sets the training configuration. If None is provided, attempts to load it
        from the trained_model_path in the metrics configuration.
        """
        if training_conf is None:
            try:
                training_conf_path = self._get_training_conf_path_from_metrics()
                logger.info(
                    f"Loading training config from inferred path: {training_conf_path}"
                )
                self._training_conf = read_conf(training_conf_path)
            except Exception as e:
                # Any failure here (missing key, missing file, parse error)
                # is fatal: the handler cannot work without a training conf.
                logger.error(f"Failed to load training config: {e}")
                raise
        else:
            self._training_conf = training_conf

    @property
    def plotting_conf(self):
        """Get Plotting Conf.

        A function to return the plotting configuration file name.

        Returns
        -------
        str
            Name of plotting configuration file (may be None if plotting
            was not configured).
        """
        return self.metrics_conf.metrics.plotting_config

    def _load_trained_psf_model(self):
        """Load the trained PSF model from its saved weights.

        Builds a glob-style weights path pattern from the metrics and
        training configurations and delegates to ``load_trained_psf_model``.

        Returns
        -------
        The trained PSF model instance returned by ``load_trained_psf_model``.

        Raises
        ------
        KeyError
            If ``model_save_path`` or ``saved_training_cycle`` is missing
            from the metrics configuration.
        """
        trained_model_path = self._get_trained_model_path()
        try:
            model_subdir = self.metrics_conf.metrics.model_save_path
            cycle = self.metrics_conf.metrics.saved_training_cycle
        except AttributeError as e:
            # RecursiveNamespace raises AttributeError for missing fields;
            # surface it as a KeyError about the config contents.
            raise KeyError("Missing required model config fields.") from e
        model_name = self.training_conf.training.model_params.model_name
        id_name = self.training_conf.training.id_name
        # Pattern with wildcards; the loader is expected to resolve it to
        # the concrete weights file for the requested training cycle.
        weights_path_pattern = os.path.join(
            trained_model_path,
            model_subdir,
            (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"),
        )
        return load_trained_psf_model(
            self.training_conf,
            self.data_conf,
            weights_path_pattern,
        )

    def _get_training_conf_path_from_metrics(self):
        """
        Retrieves the full path to the training config based on the metrics configuration.

        Returns
        -------
        str
            Full path to the training configuration file.

        Raises
        ------
        KeyError
            If 'trained_model_config' key is missing.
        FileNotFoundError
            If the file does not exist at the constructed path.
        """
        trained_model_path = self._get_trained_model_path()
        try:
            training_conf_filename = self._metrics_conf.metrics.trained_model_config
        except AttributeError as e:
            raise KeyError(
                "Missing 'trained_model_config' key in metrics configuration."
            ) from e
        training_conf_path = os.path.join(
            self._file_handler.get_config_dir(trained_model_path),
            training_conf_filename,
        )
        if not os.path.exists(training_conf_path):
            raise FileNotFoundError(
                f"Training config file not found: {training_conf_path}"
            )
        return training_conf_path

    def _get_trained_model_path(self):
        """
        Determine the trained model path from either:

        1. The metrics configuration file (i.e., for metrics-only runs after training), or
        2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation).

        Returns
        -------
        str
            Path to the trained model directory.

        Raises
        ------
        ConfigParameterError
            If the path specified in the metrics config is invalid or missing.
        """
        trained_model_path = getattr(
            self._metrics_conf.metrics, "trained_model_path", None
        )
        if trained_model_path:
            if not os.path.isdir(trained_model_path):
                raise ConfigParameterError(
                    f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}"
                )
            logger.info(
                f"Using trained model path from metrics config: {trained_model_path}"
            )
            return trained_model_path
        # Fallback for single-run training + metrics evaluation mode
        fallback_path = os.path.join(
            self._file_handler.output_path,
            self._file_handler.parent_output_dir,
            self._file_handler.workdir,
        )
        logger.info(
            f"Using fallback trained model path from runtime file handler: {fallback_path}"
        )
        return fallback_path

    def _load_data_conf(self):
        """Load Data Conf.

        A method to load the data configuration file
        and return an instance of DataConfigHandler class.

        Returns
        -------
        An instance of the DataConfigHandler class.

        Raises
        ------
        ConfigParameterError
            If the data configuration cannot be loaded (e.g. a missing or
            malformed path in the training configuration).
        """
        try:
            return DataConfigHandler(
                os.path.join(
                    self._file_handler.config_path,
                    self.training_conf.training.data_config,
                ),
                self.training_conf.training.model_params,
            )
        except TypeError as e:
            logger.exception(e)
            raise ConfigParameterError("Data configuration loading error.")

    def call_plot_config_handler_run(self, model_metrics):
        """Make Metrics Plots.

        A function to call the PlottingConfigHandler run
        command to generate metrics plots.

        Parameters
        ----------
        model_metrics: dict
            A dictionary storing the metrics
            output generated during evaluation
            of the trained PSF model.
        """
        self._plotting_conf = os.path.join(
            self._file_handler.config_path,
            self.plotting_conf,
        )
        plots_config_handler = PlottingConfigHandler(
            self._plotting_conf,
            self._file_handler,
        )
        # Update metrics_confs dict with latest result
        plots_config_handler.metrics_confs[self._file_handler.workdir] = (
            self.metrics_conf
        )
        # Update metric results dict with latest result; keyed by
        # model name + id, matching the layout produced by
        # PlottingConfigHandler.load_metrics_into_dict.
        plots_config_handler.list_of_metrics_dict[self._file_handler.workdir] = [
            {
                self.training_conf.training.model_params.model_name
                + self.training_conf.training.id_name: [model_metrics]
            }
        ]
        plots_config_handler.run()

    def run(self):
        """Run.

        A function to run WaveDiff according to the
        input configuration: evaluate the trained PSF model and,
        if a plotting config is present, generate metric plots.
        """
        logger.info("Running metrics evaluation on trained PSF model...")
        model_metrics = evaluate_model(
            self.metrics_conf.metrics,
            self.training_conf.training,
            self.data_conf,
            self.trained_psf_model,
            self.metrics_dir,
        )
        if self.plotting_conf is not None:
            self._file_handler.copy_conffile_to_output_dir(
                self.metrics_conf.metrics.plotting_config
            )
            self.call_plot_config_handler_run(model_metrics)
@register_configclass
class PlottingConfigHandler:
    """PlottingConfigHandler.

    A class to handle plotting config settings.

    Parameters
    ----------
    ids: tuple
        A tuple containing a string id for the Configuration Class
    plotting_conf: str
        Name of plotting configuration file
    file_handler: obj
        An instance of the FileIOHandler class
    """

    ids = ("plotting_conf",)

    def __init__(self, plotting_conf, file_handler):
        self.plotting_conf = read_conf(plotting_conf)
        self.file_handler = file_handler
        self.metrics_confs = {}
        self.check_and_update_metrics_confs()
        self.list_of_metrics_dict = self.make_dict_of_metrics()
        self.plots_dir = self.file_handler.get_plots_dir(
            self.file_handler._run_output_dir
        )

    def check_and_update_metrics_confs(self):
        """Check and Update Metrics Confs.

        A function to check if user provided inputs metrics
        dir to add to metrics_confs dictionary.
        """
        if self.plotting_conf.plotting_params.metrics_dir:
            self._update_metrics_confs()

    def make_dict_of_metrics(self):
        """Make dictionary of metrics.

        A function to create a dictionary for each metrics per run.

        Returns
        -------
        dict
            A dictionary containing metrics or an empty dictionary.
        """
        if self.plotting_conf.plotting_params.metrics_dir:
            return self.load_metrics_into_dict()
        return {}

    def _update_metrics_confs(self):
        """Update Metrics Configurations.

        A method to update the metrics_confs dictionary
        with each set of metrics configuration parameters
        provided as input.
        """
        # metrics_dir and metrics_config are parallel lists: one metrics
        # config file per run output directory.
        for wf_dir, metrics_conf in zip(
            self.plotting_conf.plotting_params.metrics_dir,
            self.plotting_conf.plotting_params.metrics_config,
        ):
            self.metrics_confs[wf_dir] = read_conf(
                os.path.join(
                    self.plotting_conf.plotting_params.metrics_output_path,
                    self.file_handler.get_config_dir(wf_dir),
                    metrics_conf,
                )
            )

    def _metrics_run_id_name(self, wf_outdir, metrics_params):
        """Get Metrics Run ID Name.

        A function to generate run id name
        for the metrics of a trained model

        Parameters
        ----------
        wf_outdir: str
            Name of the wf-psf run output directory
        metrics_params: RecursiveNamespace Object
            RecursiveNamespace object containing the metrics parameters used to evaluate the trained model.

        Returns
        -------
        metrics_run_id_name: list
            List containing the model name and id for each training run
        """
        try:
            training_conf = read_conf(
                os.path.join(
                    metrics_params.metrics.trained_model_path,
                    metrics_params.metrics.trained_model_config,
                )
            )
            id_name = training_conf.training.id_name
            model_name = training_conf.training.model_params.model_name
            return [model_name + id_name]
        except (TypeError, FileNotFoundError):
            # Trained model path/config was not provided (TypeError from
            # os.path.join on None) or does not exist; fall back to the
            # training config(s) stored in the run's working directory.
            # NOTE: a second `except FileNotFoundError:` clause that
            # followed this handler was unreachable (FileNotFoundError is
            # already caught here) and has been removed.
            logger.info("Trained model path not provided...")
            logger.info(
                f"Trying to retrieve training config file from workdir: {wf_outdir}"
            )
            training_confs = [
                read_conf(training_conf)
                for training_conf in glob.glob(
                    os.path.join(
                        self.plotting_conf.plotting_params.metrics_output_path,
                        self.file_handler.get_config_dir(wf_outdir),
                        "training*",
                    )
                )
            ]
            run_ids = [
                training_conf.training.model_params.model_name
                + training_conf.training.id_name
                for training_conf in training_confs
            ]
            return run_ids

    def load_metrics_into_dict(self):
        """Load Metrics into Dictionary.

        A method to load a metrics file of
        a trained model from a previous run into a
        dictionary.

        Returns
        -------
        metrics_files_dict: dict
            A dictionary containing all of the metrics from the loaded metrics files.
        """
        metrics_dict = {}
        for k, v in self.metrics_confs.items():
            run_id_names = self._metrics_run_id_name(k, v)
            metrics_dict[k] = []
            for run_id_name in run_id_names:
                output_path = os.path.join(
                    self.plotting_conf.plotting_params.metrics_output_path,
                    k,
                    "metrics",
                    "metrics-" + run_id_name + ".npy",
                )
                logger.info(
                    f"Attempting to read in trained model config file...{output_path}"
                )
                try:
                    # np.load(...)[()] unwraps the 0-d object array produced
                    # when a dict was saved with np.save.
                    metrics_dict[k].append(
                        {run_id_name: [np.load(output_path, allow_pickle=True)[()]]}
                    )
                except FileNotFoundError:
                    # Best-effort: skip missing metrics files but tell the
                    # user, since the resulting plots will be incomplete.
                    logger.error(
                        "The required file for the plots was not found. Please check your configs settings."
                    )
        return metrics_dict

    def run(self):
        """Run.

        A function to run wave-diff according to the
        input configuration.
        """
        logger.info("Generating metric plots...")
        plot_metrics(
            self.plotting_conf.plotting_params,
            self.list_of_metrics_dict,
            self.metrics_confs,
            self.plots_dir,
        )