"""Configs_Handler.
A module which provides general utility methods
to manage the parameters of the config files
:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""
import numpy as np
from wf_psf.utils.read_config import read_conf
from wf_psf.data.training_preprocessing import TrainingDataHandler, TestDataHandler
from wf_psf.training import train
from wf_psf.psf_models import psf_models
from wf_psf.metrics.metrics_interface import evaluate_model
from wf_psf.plotting.plots_interface import plot_metrics
import logging
from wf_psf.utils.io import FileIOHandler
import os
import re
import glob
logger = logging.getLogger(__name__)
CONFIG_CLASS = {}
[docs]
def register_configclass(config_class):
"""Register Config Class.
A wrapper function to register all config classes
in a dictionary.
Parameters
----------
config_class: type
Config Class
Returns
-------
config_class: type
Config class
"""
for id in config_class.ids:
CONFIG_CLASS[id] = config_class
return config_class
[docs]
def set_run_config(config_name):
"""Set Config Class.
A function to select the class of
a configuration from CONFIG_CLASS dictionary.
Parameters
----------
config_name: str
Name of config
Returns
-------
config_class: class
Name of config class
"""
try:
config_id = [id for id in CONFIG_CLASS.keys() if re.search(id, config_name)][0]
config_class = CONFIG_CLASS[config_id]
except KeyError as e:
logger.exception("Config name entered is invalid. Check your config settings.")
exit()
return config_class
[docs]
def get_run_config(run_config, config_params, file_handler):
"""Get Run Configuration.
A function to get the configuration
for a wf-psf run.
Parameters
----------
run_config: str
Name of the type of run configuraton
config_params: str
Path of the run configuration file
file_handler: object
A class instance of FileIOHandler
Returns
-------
config_class: object
A class instance of the selected configuration class.
"""
config_class = set_run_config(run_config)
return config_class(config_params, file_handler)
[docs]
class ConfigParameterError(Exception):
"""Custom Config Parameter Error exception class for specific error scenarios."""
def __init__(self, message="An error with your config settings occurred."):
self.message = message
super().__init__(self.message)
[docs]
class DataConfigHandler:
"""DataConfigHandler.
A class to handle data configuration
parameters.
Parameters
----------
data_conf: str
Path of the data configuration file
training_model_params: Recursive Namespace object
Recursive Namespace object containing the training model parameters
"""
def __init__(self, data_conf, training_model_params):
try:
self.data_conf = read_conf(data_conf)
except FileNotFoundError as e:
logger.exception(e)
exit()
except TypeError as e:
logger.exception(e)
exit()
self.simPSF = psf_models.simPSF(training_model_params)
self.training_data = TrainingDataHandler(
self.data_conf.data.training,
self.simPSF,
training_model_params.n_bins_lda,
)
self.test_data = TestDataHandler(
self.data_conf.data.test,
self.simPSF,
training_model_params.n_bins_lda,
)
[docs]
@register_configclass
class TrainingConfigHandler:
"""TrainingConfigHandler.
A class to handle training configuration
parameters.
Parameters
----------
ids: tuple
A tuple containing a string id for the Configuration Class
training_conf: str
Path of the training configuration file
file_handler: object
A instance of the FileIOHandler class
"""
ids = ("training_conf",)
def __init__(self, training_conf, file_handler):
self.training_conf = read_conf(training_conf)
self.file_handler = file_handler
self.data_conf = DataConfigHandler(
os.path.join(
file_handler.config_path, self.training_conf.training.data_config
),
self.training_conf.training.model_params,
)
self.file_handler.copy_conffile_to_output_dir(
self.training_conf.training.data_config
)
self.checkpoint_dir = file_handler.get_checkpoint_dir(
self.file_handler._run_output_dir
)
self.optimizer_dir = file_handler.get_optimizer_dir(
self.file_handler._run_output_dir
)
self.psf_model_dir = file_handler.get_psf_model_dir(
self.file_handler._run_output_dir
)
[docs]
def run(self):
"""Run.
A function to run wave-diff according to the
input configuration.
"""
train.train(
self.training_conf.training,
self.data_conf.training_data,
self.data_conf.test_data,
self.checkpoint_dir,
self.optimizer_dir,
self.psf_model_dir,
)
if self.training_conf.training.metrics_config is not None:
self.file_handler.copy_conffile_to_output_dir(
self.training_conf.training.metrics_config
)
metrics = MetricsConfigHandler(
os.path.join(
self.file_handler.config_path,
self.training_conf.training.metrics_config,
),
self.file_handler,
self.training_conf,
)
metrics.run()
[docs]
@register_configclass
class MetricsConfigHandler:
"""MetricsConfigHandler.
A class to handle metrics configuation
parameters.
Parameters
----------
ids: tuple
A tuple containing a string id for the Configuration Class
metrics_conf: str
Path to the metrics configuration file
file_handler: object
An instance of the FileIOHandler
training_conf: RecursiveNamespace object
RecursiveNamespace object containing the training configuration parameters
"""
ids = ("metrics_conf",)
def __init__(self, metrics_conf, file_handler, training_conf=None):
self._metrics_conf = read_conf(metrics_conf)
self._file_handler = file_handler
self.trained_model_path = self._get_trained_model_path(training_conf)
self._training_conf = self._load_training_conf(training_conf)
@property
def metrics_conf(self):
return self._metrics_conf
@property
def metrics_dir(self):
return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir)
@property
def training_conf(self):
return self._training_conf
@property
def plotting_conf(self):
return self.metrics_conf.metrics.plotting_config
@property
def data_conf(self):
return self._load_data_conf()
@property
def psf_model(self):
return psf_models.get_psf_model(
self.training_conf.training.model_params,
self.training_conf.training.training_hparams,
)
@property
def weights_path(self):
return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath)
def _get_trained_model_path(self, training_conf):
"""Get Trained Model Path.
Helper method to get the trained model path.
Parameters
----------
training_conf: None or RecursiveNamespace
None type or RecursiveNamespace
Returns
-------
str
A string representing the path to the trained model output run directory.
"""
if training_conf is None:
try:
return self._metrics_conf.metrics.trained_model_path
except TypeError as e:
logger.exception(e)
raise ConfigParameterError(
"Metrics config file trained model path or config values are empty."
)
else:
return os.path.join(
self._file_handler.output_path,
self._file_handler.parent_output_dir,
self._file_handler.workdir,
)
def _load_training_conf(self, training_conf):
"""Load Training Conf.
Load the training configuration if training_conf is not provided.
Parameters
----------
training_conf: None or RecursiveNamespace
None type or a RecursiveNamespace storing the training configuration parameter setttings.
Returns
-------
RecursiveNamespace storing the training configuration parameter settings.
"""
if training_conf is None:
try:
return read_conf(
os.path.join(
self._file_handler.get_config_dir(self.trained_model_path),
self._metrics_conf.metrics.trained_model_config,
)
)
except TypeError as e:
logger.exception(e)
raise ConfigParameterError(
"Metrics config file trained model path or config values are empty."
)
else:
return training_conf
def _load_data_conf(self):
"""Load Data Conf.
A method to load the data configuration file
and return an instance of DataConfigHandler class.
Returns
-------
An instance of the DataConfigHandler class.
"""
try:
return DataConfigHandler(
os.path.join(
self._file_handler.config_path,
self.training_conf.training.data_config,
),
self.training_conf.training.model_params,
)
except TypeError as e:
logger.exception(e)
raise ConfigParameterError("Data configuration loading error.")
@property
def weights_basename_filepath(self):
"""Get PSF model weights filepath.
A function to return the basename of the user-specified psf model weights path.
Returns
-------
weights_basename: str
The basename of the psf model weights to be loaded.
"""
return os.path.join(
self.trained_model_path,
self.metrics_conf.metrics.model_save_path,
(
f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}"
f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*"
),
)
[docs]
def call_plot_config_handler_run(self, model_metrics):
"""Make Metrics Plots.
A function to call the PlottingConfigHandler run
command to generate metrics plots.
Parameters
----------
model_metrics: dict
A dictionary storing the metrics
output generated during evaluation
of the trained PSF model.
"""
self._plotting_conf = os.path.join(
self._file_handler.config_path,
self.plotting_conf,
)
plots_config_handler = PlottingConfigHandler(
self._plotting_conf,
self._file_handler,
)
# Update metrics_confs dict with latest result
plots_config_handler.metrics_confs[
self._file_handler.workdir
] = self.metrics_conf
# Update metric results dict with latest result
plots_config_handler.list_of_metrics_dict[self._file_handler.workdir] = [
{
self.training_conf.training.model_params.model_name
+ self.training_conf.training.id_name: [model_metrics]
}
]
plots_config_handler.run()
[docs]
def run(self):
"""Run.
A function to run wave-diff according to the
input configuration.
"""
logger.info(
"Running metrics evaluation on psf model: {}".format(self.weights_path)
)
model_metrics = evaluate_model(
self.metrics_conf.metrics,
self.training_conf.training,
self.data_conf.training_data,
self.data_conf.test_data,
self.psf_model,
self.weights_path,
self.metrics_dir,
)
if self.plotting_conf is not None:
self._file_handler.copy_conffile_to_output_dir(
self.metrics_conf.metrics.plotting_config
)
self.call_plot_config_handler_run(model_metrics)
[docs]
@register_configclass
class PlottingConfigHandler:
"""PlottingConfigHandler.
A class to handle plotting config
settings.
Parameters
----------
ids: tuple
A tuple containing a string id for the Configuration Class
plotting_conf: str
Name of plotting configuration file
file_handler: obj
An instance of the FileIOHandler class
"""
ids = ("plotting_conf",)
def __init__(self, plotting_conf, file_handler):
self.plotting_conf = read_conf(plotting_conf)
self.file_handler = file_handler
self.metrics_confs = {}
self.check_and_update_metrics_confs()
self.list_of_metrics_dict = self.make_dict_of_metrics()
self.plots_dir = self.file_handler.get_plots_dir(
self.file_handler._run_output_dir
)
[docs]
def check_and_update_metrics_confs(self):
"""Check and Update Metrics Confs.
A function to check if user provided inputs metrics
dir to add to metrics_confs dictionary.
"""
if self.plotting_conf.plotting_params.metrics_dir:
self._update_metrics_confs()
[docs]
def make_dict_of_metrics(self):
"""Make dictionary of metrics.
A function to create a dictionary for each metrics per run.
Returns
-------
dict
A dictionary containing metrics or an empty dictionary.
"""
if self.plotting_conf.plotting_params.metrics_dir:
return self.load_metrics_into_dict()
else:
return {}
def _update_metrics_confs(self):
"""Update Metrics Configurations.
A method to update the metrics_confs dictionary
with each set of metrics configuration parameters
provided as input.
"""
for wf_dir, metrics_conf in zip(
self.plotting_conf.plotting_params.metrics_dir,
self.plotting_conf.plotting_params.metrics_config,
):
self.metrics_confs[wf_dir] = read_conf(
os.path.join(
self.plotting_conf.plotting_params.metrics_output_path,
self.file_handler.get_config_dir(wf_dir),
metrics_conf,
)
)
def _metrics_run_id_name(self, wf_outdir, metrics_params):
"""Get Metrics Run ID Name.
A function to generate run id name
for the metrics of a trained model
Parameters
----------
wf_outdir: str
Name of the wf-psf run output directory
metrics_params: RecursiveNamespace Object
RecursiveNamespace object containing the metrics parameters used to evaluated the trained model.
Returns
-------
metrics_run_id_name: list
List containing the model name and id for each training run
"""
try:
training_conf = read_conf(
os.path.join(
metrics_params.metrics.trained_model_path,
metrics_params.metrics.trained_model_config,
)
)
id_name = training_conf.training.id_name
model_name = training_conf.training.model_params.model_name
return [model_name + id_name]
except (TypeError, FileNotFoundError):
logger.info("Trained model path not provided...")
logger.info(
"Trying to retrieve training config file from workdir: {}".format(
wf_outdir
)
)
training_confs = [
read_conf(training_conf)
for training_conf in glob.glob(
os.path.join(
self.plotting_conf.plotting_params.metrics_output_path,
self.file_handler.get_config_dir(wf_outdir),
"training*",
)
)
]
run_ids = [
training_conf.training.model_params.model_name
+ training_conf.training.id_name
for training_conf in training_confs
]
return run_ids
except:
logger.exception("File not found.")
[docs]
def load_metrics_into_dict(self):
"""Load Metrics into Dictionary.
A method to load a metrics file of
a trained model from a previous run into a
dictionary.
Returns
-------
metrics_files_dict: dict
A dictionary containing all of the metrics from the loaded metrics files.
"""
metrics_dict = {}
for k, v in self.metrics_confs.items():
run_id_names = self._metrics_run_id_name(k, v)
metrics_dict[k] = []
for run_id_name in run_id_names:
output_path = os.path.join(
self.plotting_conf.plotting_params.metrics_output_path,
k,
"metrics",
"metrics-" + run_id_name + ".npy",
)
logger.info(
"Attempting to read in trained model config file...{}".format(
output_path
)
)
try:
metrics_dict[k].append(
{run_id_name: [np.load(output_path, allow_pickle=True)[()]]}
)
except FileNotFoundError:
logger.error(
"The required file for the plots was not found. Please check your configs settings."
)
return metrics_dict
[docs]
def run(self):
"""Run.
A function to run wave-diff according to the
input configuration.
"""
logger.info("Generating metric plots...")
plot_metrics(
self.plotting_conf.plotting_params,
self.list_of_metrics_dict,
self.metrics_confs,
self.plots_dir,
)