"""Inference.
A module which provides a `PSFInference` class to perform inference
with trained PSF models. It is able to load a trained model,
perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs.
:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>
"""
import os
from pathlib import Path
import numpy as np
from typing import Optional
from wf_psf.data.data_adapter import StructureState, RepresentationState
from wf_psf.data.data_config_handler import DataConfigHandler
from wf_psf.data.data_adapter import DataAdapter
from wf_psf.data.factory import DataAdapterFactory
from wf_psf.data.schemas import DatasetMode
from wf_psf.utils.read_config import read_conf
from wf_psf.psf_models import psf_models
from wf_psf.psf_models.psf_model_loader import load_trained_psf_model
import tensorflow as tf
import logging
logger = logging.getLogger(__name__)
class InferenceConfigHandler:
    """
    Handle configuration loading and management for PSF inference.

    This class manages the loading of inference, training, and data configuration
    files required for PSF inference operations.

    Parameters
    ----------
    inference_config_path : str
        Path to the inference configuration YAML file.

    Attributes
    ----------
    inference_config_path : str
        Path to the inference configuration file.
    inference_config : RecursiveNamespace or None
        Loaded inference configuration.
    training_config : RecursiveNamespace or None
        Loaded training configuration.
    data_config : DataConfigHandler or None
        Data configuration handler built from `data_config_path`.
    trained_model_path : Path or None
        Path to the trained model directory.
    model_subdir : str or None
        Subdirectory name for model files.
    trained_model_config_path : Path or None
        Path to the training configuration file.
    data_config_path : Path or None
        Path to the data configuration file, or None when the inference
        configuration does not specify one.

    Notes
    -----
    All attributes other than `inference_config_path` are None until
    `load_configs` (or `set_config_paths` for the path attributes) is called.
    """

    ids = ("inference_conf",)

    def __init__(self, inference_config_path: str):
        self.inference_config_path = inference_config_path
        self.inference_config = None
        self.training_config = None
        self.data_config = None
        # Filled in by set_config_paths(); declared here so the attributes
        # always exist, even before load_configs() has run.
        self.trained_model_path = None
        self.model_subdir = None
        self.trained_model_config_path = None
        self.data_config_path = None

    def load_configs(self):
        """
        Load configuration files based on the inference config.

        Loads the inference configuration first, then uses it to determine and load
        the training and data configurations.

        Notes
        -----
        Updates the following attributes in-place:

        - `inference_config`
        - `training_config`
        - `data_config` (a `DataConfigHandler` if `data_config_path` is
          specified, otherwise None)
        """
        self.inference_config = read_conf(self.inference_config_path).inference
        self.set_config_paths()
        self.training_config = read_conf(self.trained_model_config_path).training

        # The data configuration is optional: only build the handler when the
        # inference config actually points at one.
        if self.data_config_path is not None:
            self.data_config = DataConfigHandler(self.data_config_path)
        else:
            self.data_config = None

    def set_config_paths(self):
        """
        Extract and set the configuration paths from the inference config.

        Sets the following attributes:

        - trained_model_path
        - model_subdir
        - trained_model_config_path
        - data_config_path

        Notes
        -----
        `trained_model_config_path` and `data_config_path` are joined onto
        `trained_model_path`. `data_config_path` is set to None when the
        `configs` section has no `data_config_path` entry or the entry is None.
        """
        config_paths = self.inference_config.configs

        self.trained_model_path = Path(config_paths.trained_model_path)
        self.model_subdir = config_paths.model_subdir
        self.trained_model_config_path = (
            self.trained_model_path / config_paths.trained_model_config_path
        )

        # ``Path / None`` raises TypeError, so the optional entry has to be
        # checked before joining it onto the model directory.
        data_config_path = getattr(config_paths, "data_config_path", None)
        if data_config_path is None:
            self.data_config_path = None
        else:
            self.data_config_path = self.trained_model_path / data_config_path

    @property
    def schema_mode(self) -> "DatasetMode":
        """
        Get the dataset schema mode requested by the inference configuration.

        Returns
        -------
        DatasetMode
            The `DatasetMode` member whose name matches the upper-cased
            `schema_mode` entry of the inference configuration.

        Raises
        ------
        ValueError
            If the configured value does not name a `DatasetMode` member.
        """
        raw = self.inference_config.schema_mode.upper()
        try:
            return DatasetMode[raw]
        except KeyError as err:
            raise ValueError(
                f"Unknown schema_mode {raw!r} in inference configuration "
                f"'{self.inference_config_path}'."
            ) from err

    @staticmethod
    def overwrite_model_params(training_config=None, inference_config=None):
        """
        Overwrite training model_params with values from inference_config if available.

        Parameters
        ----------
        training_config : RecursiveNamespace
            Configuration object from training phase.
        inference_config : RecursiveNamespace
            Configuration object from inference phase.

        Notes
        -----
        Updates are applied in-place to ``training_config.model_params``.
        Only keys that already exist on the training ``model_params`` are
        overwritten; keys present only in the inference config are ignored.
        If either configuration is None, or either lacks (or has an empty)
        ``model_params``, nothing is changed.
        """
        model_params = getattr(training_config, "model_params", None)
        inf_model_params = getattr(inference_config, "model_params", None)
        if not (model_params and inf_model_params):
            return

        for key, value in vars(inf_model_params).items():
            if hasattr(model_params, key):
                setattr(model_params, key, value)
class PSFInference:
    """
    Perform PSF inference using a pre-trained WaveDiff model.

    This class handles the setup for PSF inference, including loading configuration
    files, instantiating the PSF simulator and data handler, and preparing the
    input data required for inference.

    Parameters
    ----------
    inference_config_path : str
        Path to the inference configuration YAML file.
    x_field : array-like, optional
        x coordinates in SHE convention.
    y_field : array-like, optional
        y coordinates in SHE convention.
    seds : array-like, optional
        Spectral energy distributions (SEDs).
    sources : array-like, optional
        Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]).
    masks : array-like, optional
        Corresponding masks for the sources (same shape as sources). Defaults to None.

    Attributes
    ----------
    inference_config_path : str
        Path to the inference configuration file.
    x_field : array-like or None
        x coordinates for PSF positions.
    y_field : array-like or None
        y coordinates for PSF positions.
    seds : array-like or None
        Spectral energy distributions.
    sources : array-like or None
        Source postage stamps.
    masks : array-like or None
        Source masks.
    engine : PSFInferenceEngine or None
        The inference engine instance.

    Examples
    --------
    Basic usage with position coordinates and SEDs:

    .. code-block:: python

        psf_inf = PSFInference(
            inference_config_path="config.yaml",
            x_field=[100.5, 200.3],
            y_field=[150.2, 250.8],
            seds=sed_array,
        )
        psf_inf.run_inference()
        psf = psf_inf.get_psf(0)
    """

    def __init__(
        self,
        inference_config_path: str,
        x_field=None,
        y_field=None,
        seds=None,
        sources=None,
        masks=None,
    ):
        self.inference_config_path = inference_config_path

        # Inputs for the model
        self.x_field = x_field
        self.y_field = y_field
        self.seds = seds
        self.sources = sources
        self.masks = masks

        # Internal caches for lazy-loading
        self._config_handler = None
        self._simPSF = None
        self._trained_psf_model = None
        self._n_bins_lambda = None
        self._batch_size = None
        self._cycle = None
        self._output_dim = None

        # Lazily built data adapters
        self._model_data_adapter: Optional[DataAdapter] = None
        self._inference_data_adapter: Optional[DataAdapter] = None

        # PSF inference engine, created by run_inference()
        self.engine = None

    @property
    def config_handler(self):
        """
        Get or create the configuration handler.

        Returns
        -------
        InferenceConfigHandler
            The configuration handler instance with loaded configs.
        """
        if self._config_handler is None:
            self._config_handler = InferenceConfigHandler(self.inference_config_path)
            self._config_handler.load_configs()
        return self._config_handler

    def prepare_configs(self):
        """
        Prepare the configuration for inference.

        Overwrites training model parameters with inference configuration values.
        """
        self.config_handler.overwrite_model_params(
            self.training_config,
            self.inference_config,
        )

    @property
    def inference_config(self):
        """
        Get the inference configuration.

        Returns
        -------
        RecursiveNamespace
            The inference configuration object.
        """
        return self.config_handler.inference_config

    @property
    def training_config(self):
        """
        Get the training configuration.

        Returns
        -------
        RecursiveNamespace
            The training configuration object.
        """
        return self.config_handler.training_config

    @property
    def data_config(self):
        """
        Get the data configuration.

        Returns
        -------
        object or None
            The ``data_config`` attribute of the configuration handler, or
            None if not available.
        """
        return self.config_handler.data_config

    @property
    def simPSF(self):
        """
        Get or create the PSF simulator.

        Returns
        -------
        simPSF
            The PSF simulator instance, built from the training
            ``model_params`` on first access and cached afterwards.
        """
        if self._simPSF is None:
            self._simPSF = psf_models.simPSF(self.training_config.model_params)
        return self._simPSF

    def _prepare_dataset_for_inference(self):
        """
        Prepare the input dataset dictionary for inference.

        Returns
        -------
        dict
            Dictionary containing canonical fields for inference:

            - positions
            - sources (only if provided)
            - masks (only if provided)
            - seds (only if provided)

        Raises
        ------
        ValueError
            If no positions can be built from `x_field` and `y_field`.
        """
        positions = self.get_positions()
        if positions is None:
            raise ValueError(
                "x_field and y_field must be provided for inference positions."
            )

        dataset = {"positions": positions}
        # Only include the optional fields that were actually supplied
        optional_fields = (
            ("sources", self.sources),
            ("masks", self.masks),
            ("seds", self.seds),
        )
        for name, value in optional_fields:
            if value is not None:
                dataset[name] = value
        return dataset

    @property
    def model_data_adapter(self):
        """
        Create and return a Model DataAdapter for loading trained PSF model using the factory.

        Returns
        -------
        DataAdapter
            The adapter built by ``DataAdapterFactory.build`` from the data
            configuration. Split datasets are joined before the adapter is cached.
        """
        if self._model_data_adapter is None:
            logger.info("Generating the model data adapter...")
            adapter = DataAdapterFactory.build(self.data_config)

            # Join data, if not already complete
            if adapter.structure_state == StructureState.SPLIT:
                logger.info("Joining split datasets...")
                adapter.join_data()

            self._model_data_adapter = adapter
        return self._model_data_adapter

    @property
    def inference_data_adapter(self):
        """
        Create and return a DataAdapter for inference data using the factory.

        Returns
        -------
        DataAdapter
            The adapter built by ``DataAdapterFactory.build`` from the
            dictionary returned by `_prepare_dataset_for_inference`.
        """
        if self._inference_data_adapter is None:
            logger.info("Generating the inference data adapter...")
            dataset = self._prepare_dataset_for_inference()
            self._inference_data_adapter = DataAdapterFactory.build(dataset)
        return self._inference_data_adapter

    def _convert_inference_data_to_tensorflow(self):
        """
        Convert the inference data to TensorFlow according to the dataset schema mode.

        The conversion is only requested while the adapter reports the NUMPY
        representation state.
        """
        # Go through the lazy property rather than the private cache so this
        # also works when the adapter has not been built yet.
        adapter = self.inference_data_adapter
        if adapter.representation_state == RepresentationState.NUMPY:
            adapter.convert_to_tensorflow(
                self.simPSF,
                self.n_bins_lambda,
                mode=self.config_handler.schema_mode,
            )

    @property
    def trained_psf_model(self):
        """
        Get or load the trained PSF model.

        Returns
        -------
        Model
            The loaded trained PSF model.
        """
        if self._trained_psf_model is None:
            self._trained_psf_model = self.load_inference_model()
        return self._trained_psf_model

    def get_positions(self):
        """
        Combine x_field and y_field into position pairs.

        Returns
        -------
        numpy.ndarray or None
            Array of shape (num_positions, 2) where each row contains [x, y] coordinates.
            Returns None if either x_field or y_field is None or empty.

        Raises
        ------
        ValueError
            If x_field and y_field have different lengths.
        """
        if self.x_field is None or self.y_field is None:
            return None

        x_arr = np.asarray(self.x_field)
        y_arr = np.asarray(self.y_field)

        if x_arr.size == 0 or y_arr.size == 0:
            return None

        if x_arr.size != y_arr.size:
            raise ValueError(
                f"x_field and y_field must have the same length. "
                f"Got {x_arr.size} and {y_arr.size}"
            )

        # Flatten arrays to handle any input shape, then stack as columns
        return np.column_stack((x_arr.ravel(), y_arr.ravel()))

    def load_inference_model(self):
        """Load the trained PSF model based on the inference configuration.

        Returns
        -------
        Model
            The loaded trained PSF model.

        Notes
        -----
        Constructs the weights path pattern based on the trained model path,
        model subdirectory, model name, id name, and cycle number specified in the
        configuration files.
        """
        model_path = self.config_handler.trained_model_path
        model_dir = self.config_handler.model_subdir
        model_name = self.training_config.model_params.model_name
        id_name = self.training_config.id_name

        weights_path_pattern = os.path.join(
            model_path,
            model_dir,
            f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*",
        )

        # Load the trained PSF model
        return load_trained_psf_model(
            self.training_config,
            self.inference_data_adapter.complete_data,
            weights_path_pattern,
        )

    @property
    def n_bins_lambda(self):
        """Get the number of wavelength bins for inference.

        Returns
        -------
        int
            The number of wavelength bins used during inference.
        """
        if self._n_bins_lambda is None:
            self._n_bins_lambda = self.inference_config.model_params.n_bins_lambda
        return self._n_bins_lambda

    @property
    def batch_size(self):
        """
        Get the batch size for inference.

        Returns
        -------
        int
            The batch size for processing during inference.

        Raises
        ------
        ValueError
            If the configured batch size is not greater than 0.
        """
        if self._batch_size is None:
            batch_size = self.inference_config.batch_size
            # Validate before caching so an invalid value is never stored, and
            # raise explicitly because ``assert`` is stripped under ``python -O``.
            if batch_size <= 0:
                raise ValueError("Batch size must be greater than 0.")
            self._batch_size = batch_size
        return self._batch_size

    @property
    def cycle(self):
        """Get the cycle number for inference.

        Returns
        -------
        int
            The cycle number used for loading the trained model.
        """
        if self._cycle is None:
            self._cycle = self.inference_config.cycle
        return self._cycle

    @property
    def output_dim(self):
        """Get the output dimension for PSF inference.

        Returns
        -------
        int
            The output dimension (height and width) of the inferred PSFs.
        """
        if self._output_dim is None:
            self._output_dim = self.inference_config.model_params.output_dim
        return self._output_dim

    def run_inference(self):
        """Run PSF inference and return the full PSF array.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Notes
        -----
        Prepares configurations and input data, initializes the inference engine,
        and computes the PSF for all input positions.
        """
        # Prepare the configuration for inference
        self.prepare_configs()

        self.engine = PSFInferenceEngine(
            trained_model=self.trained_psf_model,
            batch_size=self.batch_size,
            output_dim=self.output_dim,
        )

        # Convert inference data to tensorflow type
        self._convert_inference_data_to_tensorflow()

        # Get positions and SEDs
        complete_data = self.inference_data_adapter.complete_data
        positions = complete_data["positions"]
        sed_data = complete_data["seds"]

        return self.engine.compute_psfs(positions, sed_data)

    def _ensure_psf_inference_completed(self):
        """Ensure that PSF inference has been completed.

        Runs inference if it has not been done yet.
        """
        if self.engine is None or self.engine.inferred_psfs is None:
            self.run_inference()

    def get_psfs(self):
        """Get all inferred PSFs.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Notes
        -----
        Ensures automatically that inference has been completed before accessing the PSFs.
        """
        self._ensure_psf_inference_completed()
        return self.engine.get_psfs()

    def get_psf(self, index: int = 0) -> np.ndarray:
        """
        Get the PSF at a specific index.

        Parameters
        ----------
        index : int, optional
            Index of the PSF to retrieve (default is 0).

        Returns
        -------
        numpy.ndarray
            The inferred PSF at the specified index with shape (output_dim, output_dim).

        Notes
        -----
        Ensures automatically that inference has been completed before accessing the PSF.
        If only a single PSF was inferred, that PSF is returned regardless of
        `index` (bounds checking is relaxed).
        """
        self._ensure_psf_inference_completed()
        inferred_psfs = self.engine.get_psfs()

        # If a single-star batch, ignore index bounds
        if inferred_psfs.shape[0] == 1:
            return inferred_psfs[0]

        # Otherwise, return the PSF at the requested index
        return inferred_psfs[index]

    def clear_cache(self):
        """
        Clear all cached properties and reset the instance.

        This method resets all lazy-loaded properties, including the config handler,
        PSF simulator, data adapters, trained model, and inference engine. Useful for
        freeing memory or forcing a fresh initialization.

        Notes
        -----
        After calling this method, accessing any property will trigger re-initialization.
        The input data (`x_field`, `y_field`, `seds`, `sources`, `masks`) is kept.
        """
        self._config_handler = None
        self._simPSF = None
        # Both adapters must be dropped; otherwise stale data would be reused
        # by the next run even though the trained model is reloaded.
        self._model_data_adapter = None
        self._inference_data_adapter = None
        self._trained_psf_model = None
        self._n_bins_lambda = None
        self._batch_size = None
        self._cycle = None
        self._output_dim = None
        self.engine = None
class PSFInferenceEngine:
    """Engine to perform PSF inference using a trained model.

    This class handles the batch-wise computation of PSFs using a trained PSF model.
    It manages the batching of input positions and SEDs, and caches the inferred PSFs
    for later access.

    Parameters
    ----------
    trained_model : Model
        The trained PSF model to use for inference.
    batch_size : int
        The batch size for processing during inference.
    output_dim : int
        The output dimension (height and width) of the inferred PSFs.

    Attributes
    ----------
    trained_model : Model
        The trained PSF model used for inference.
    batch_size : int
        The batch size for processing during inference.
    output_dim : int
        The output dimension (height and width) of the inferred PSFs.

    Examples
    --------
    .. code-block:: python

        engine = PSFInferenceEngine(model, batch_size=32, output_dim=64)
        psfs = engine.compute_psfs(positions, seds)
        single_psf = engine.get_psf(0)
    """

    def __init__(self, trained_model, batch_size: int, output_dim: int):
        self.trained_model = trained_model
        self.batch_size = batch_size
        self.output_dim = output_dim
        self._inferred_psfs = None

    @property
    def inferred_psfs(self) -> Optional[np.ndarray]:
        """Access the cached inferred PSFs, if available.

        Returns
        -------
        numpy.ndarray or None
            The cached inferred PSFs, or None if not yet computed.
        """
        return self._inferred_psfs

    def compute_psfs(self, positions: "tf.Tensor", sed_data: "tf.Tensor") -> np.ndarray:
        """Compute and cache PSFs for the input source parameters.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (n_samples, 2) containing the (x, y) positions
        sed_data : tf.Tensor
            Tensor of shape (n_samples, n_bins, 2) containing the SEDs

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Raises
        ------
        ValueError
            If `batch_size` is not greater than 0.

        Notes
        -----
        PSFs are computed in batches according to the specified batch_size.
        Results are cached internally for subsequent access via get_psfs() or get_psf().
        The cache is cleared at the start of the call and only set once every
        batch has been computed, so a failure part-way through never leaves a
        partially filled result behind.
        """
        # A non-positive batch size would never advance through the samples.
        if self.batch_size < 1:
            raise ValueError("Batch size must be greater than 0.")

        # Invalidate any previous result until the new one is complete.
        self._inferred_psfs = None

        n_samples = positions.shape[0]
        psfs = np.zeros(
            (n_samples, self.output_dim, self.output_dim), dtype=np.float32
        )

        for start in range(0, n_samples, self.batch_size):
            # The last batch may be shorter than batch_size
            stop = min(start + self.batch_size, n_samples)

            batch_inputs = [
                positions[start:stop, :],
                sed_data[start:stop, :, :],
            ]

            # Generate PSFs for the current batch
            batch_psfs = self.trained_model(batch_inputs, training=False)
            psfs[start:stop, :, :] = batch_psfs.numpy()

        self._inferred_psfs = psfs
        return self._inferred_psfs

    def get_psfs(self) -> np.ndarray:
        """Get all the generated PSFs.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

        Raises
        ------
        ValueError
            If PSFs have not yet been computed.
        """
        if self._inferred_psfs is None:
            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
        return self._inferred_psfs

    def get_psf(self, index: int) -> np.ndarray:
        """Get the PSF at a specific index.

        Parameters
        ----------
        index : int
            Index of the PSF to retrieve.

        Returns
        -------
        numpy.ndarray
            The inferred PSF at the specified index with shape (output_dim, output_dim).

        Raises
        ------
        ValueError
            If PSFs have not yet been computed.
        """
        if self._inferred_psfs is None:
            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
        return self._inferred_psfs[index]

    def clear_cache(self):
        """
        Clear cached inferred PSFs.

        Resets the internal PSF cache to free memory. After calling this method,
        compute_psfs() must be called again before accessing PSFs.
        """
        self._inferred_psfs = None