Source code for wf_psf.inference.psf_inference

"""Inference.

A module which provides a `PSFInference` class to perform inference
with trained PSF models. It is able to load a trained model,
perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs.

:Authors: Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>

"""

import os
from pathlib import Path
import numpy as np
from typing import Optional
from wf_psf.data.data_adapter import StructureState, RepresentationState
from wf_psf.data.data_config_handler import DataConfigHandler
from wf_psf.data.data_adapter import DataAdapter
from wf_psf.data.factory import DataAdapterFactory
from wf_psf.data.schemas import DatasetMode
from wf_psf.utils.read_config import read_conf
from wf_psf.psf_models import psf_models
from wf_psf.psf_models.psf_model_loader import load_trained_psf_model
import tensorflow as tf
import logging

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class InferenceConfigHandler:
    """Handle configuration loading and management for PSF inference.

    This class manages the loading of the inference, training, and data
    configuration files required for PSF inference operations.

    Parameters
    ----------
    inference_config_path : str
        Path to the inference configuration YAML file.

    Attributes
    ----------
    inference_config_path : str
        Path to the inference configuration file.
    inference_config : RecursiveNamespace or None
        Loaded inference configuration (``None`` until `load_configs` runs).
    training_config : RecursiveNamespace or None
        Loaded training configuration (``None`` until `load_configs` runs).
    data_config : DataConfigHandler or None
        Handler built from the data configuration file, or ``None`` when no
        data configuration path is given.
    trained_model_path : Path or None
        Path to the trained model directory.
    model_subdir : str or None
        Subdirectory name for model files.
    trained_model_config_path : Path or None
        Path to the training configuration file, resolved relative to
        `trained_model_path`.
    data_config_path : Path or None
        Path to the data configuration file, resolved relative to
        `trained_model_path`, or ``None`` when not specified.
    """

    ids = ("inference_conf",)

    def __init__(self, inference_config_path: str):
        self.inference_config_path = inference_config_path
        self.inference_config = None
        self.training_config = None
        self.data_config = None

        # Filled in by ``set_config_paths``; initialised here so that reading
        # them before ``load_configs`` yields None instead of AttributeError.
        self.trained_model_path = None
        self.model_subdir = None
        self.trained_model_config_path = None
        self.data_config_path = None

    def load_configs(self):
        """Load configuration files based on the inference config.

        Loads the inference configuration first, then uses it to determine
        and load the training and data configurations.

        Notes
        -----
        Updates the following attributes in-place:

        - `inference_config`
        - `training_config`
        - `data_config` (only if `data_config_path` is not ``None``)
        """
        self.inference_config = read_conf(self.inference_config_path).inference
        self.set_config_paths()
        self.training_config = read_conf(self.trained_model_config_path).training

        if self.data_config_path is not None:
            # Load the data configuration
            self.data_config = DataConfigHandler(self.data_config_path)

    def set_config_paths(self):
        """Extract and set the configuration paths from the inference config.

        Sets the following attributes:

        - `trained_model_path`
        - `model_subdir`
        - `trained_model_config_path`
        - `data_config_path`

        Notes
        -----
        `data_config_path` is set to ``None`` when the inference config does
        not provide one, so `load_configs` can skip the data configuration.
        """
        config_paths = self.inference_config.configs

        self.trained_model_path = Path(config_paths.trained_model_path)
        self.model_subdir = config_paths.model_subdir
        self.trained_model_config_path = (
            self.trained_model_path / config_paths.trained_model_config_path
        )

        # ``Path / None`` raises TypeError, so only join when a value exists.
        data_config_path = getattr(config_paths, "data_config_path", None)
        if data_config_path is None:
            self.data_config_path = None
        else:
            self.data_config_path = self.trained_model_path / data_config_path

    @property
    def schema_mode(self) -> "DatasetMode":
        """Return the dataset schema mode named in the inference config.

        The ``schema_mode`` entry of the inference configuration is
        upper-cased and looked up by name in `DatasetMode`.

        Returns
        -------
        DatasetMode
            The member whose name matches the configured schema mode.

        Raises
        ------
        ValueError
            If the configured value is not the name of a `DatasetMode` member.
        """
        raw = self.inference_config.schema_mode.upper()
        try:
            return DatasetMode[raw]
        except KeyError as err:
            raise ValueError(
                f"Invalid schema_mode {self.inference_config.schema_mode!r} "
                "in the inference configuration."
            ) from err

    @staticmethod
    def overwrite_model_params(training_config=None, inference_config=None):
        """Overwrite training model_params with values from inference_config.

        Only keys that already exist on ``training_config.model_params`` are
        overwritten; keys present only in the inference config are ignored.

        Parameters
        ----------
        training_config : RecursiveNamespace, optional
            Configuration object from the training phase.
        inference_config : RecursiveNamespace, optional
            Configuration object from the inference phase.

        Notes
        -----
        Updates are applied in-place to ``training_config.model_params``.
        If either configuration is ``None``, or either ``model_params`` is
        falsy, nothing is changed.
        """
        # Both arguments default to None; there is nothing to merge then.
        if training_config is None or inference_config is None:
            return

        model_params = training_config.model_params
        inf_model_params = inference_config.model_params
        if not model_params or not inf_model_params:
            return

        for key, value in vars(inf_model_params).items():
            if hasattr(model_params, key):
                setattr(model_params, key, value)
class PSFInference:
    """Perform PSF inference using a pre-trained WaveDiff model.

    This class handles the setup for PSF inference, including loading
    configuration files, instantiating the PSF simulator and data adapters,
    and preparing the input data required for inference.

    Parameters
    ----------
    inference_config_path : str
        Path to the inference configuration YAML file.
    x_field : array-like, optional
        x coordinates in SHE convention.
    y_field : array-like, optional
        y coordinates in SHE convention.
    seds : array-like, optional
        Spectral energy distributions (SEDs).
    sources : array-like, optional
        Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]).
    masks : array-like, optional
        Corresponding masks for the sources (same shape as sources).
        Defaults to None.

    Attributes
    ----------
    inference_config_path : str
        Path to the inference configuration file.
    x_field : array-like or None
        x coordinates for PSF positions.
    y_field : array-like or None
        y coordinates for PSF positions.
    seds : array-like or None
        Spectral energy distributions.
    sources : array-like or None
        Source postage stamps.
    masks : array-like or None
        Source masks.
    engine : PSFInferenceEngine or None
        The inference engine instance.

    Examples
    --------
    Basic usage with position coordinates and SEDs:

    .. code-block:: python

        psf_inf = PSFInference(
            inference_config_path="config.yaml",
            x_field=[100.5, 200.3],
            y_field=[150.2, 250.8],
            seds=sed_array
        )
        psf_inf.run_inference()
        psf = psf_inf.get_psf(0)
    """

    def __init__(
        self,
        inference_config_path: str,
        x_field=None,
        y_field=None,
        seds=None,
        sources=None,
        masks=None,
    ):
        self.inference_config_path = inference_config_path

        # Inputs for the model
        self.x_field = x_field
        self.y_field = y_field
        self.seds = seds
        self.sources = sources
        self.masks = masks

        # Internal caches for lazy-loading
        self._config_handler = None
        self._simPSF = None
        self._trained_psf_model = None
        self._n_bins_lambda = None
        self._batch_size = None
        self._cycle = None
        self._output_dim = None

        # Initialise Data Adapters
        self._model_data_adapter: Optional[DataAdapter] = None
        self._inference_data_adapter: Optional[DataAdapter] = None

        # Initialise PSF Inference engine
        self.engine = None

    @property
    def config_handler(self):
        """Get or create the configuration handler.

        Returns
        -------
        InferenceConfigHandler
            The configuration handler instance with loaded configs.
        """
        if self._config_handler is None:
            self._config_handler = InferenceConfigHandler(self.inference_config_path)
            self._config_handler.load_configs()
        return self._config_handler

    def prepare_configs(self):
        """Prepare the configuration for inference.

        Overwrites training model parameters with inference configuration
        values.
        """
        self.config_handler.overwrite_model_params(
            self.training_config,
            self.inference_config,
        )

    @property
    def inference_config(self):
        """Get the inference configuration.

        Returns
        -------
        RecursiveNamespace
            The inference configuration object.
        """
        return self.config_handler.inference_config

    @property
    def training_config(self):
        """Get the training configuration.

        Returns
        -------
        RecursiveNamespace
            The training configuration object.
        """
        return self.config_handler.training_config

    @property
    def data_config(self):
        """Get the data configuration.

        Returns
        -------
        object or None
            The ``data_config`` attribute of the configuration handler, or
            None if not available.
        """
        return self.config_handler.data_config

    @property
    def simPSF(self):
        """Get or create the PSF simulator.

        Returns
        -------
        simPSF
            The PSF simulator instance, built from the training
            ``model_params``.
        """
        if self._simPSF is None:
            self._simPSF = psf_models.simPSF(self.training_config.model_params)
        return self._simPSF

    def _prepare_dataset_for_inference(self):
        """Prepare the input dataset dictionary for inference.

        Returns
        -------
        dict
            Dictionary containing canonical fields for inference:

            - positions
            - sources (optional)
            - masks (optional)
            - seds (optional)

        Raises
        ------
        ValueError
            If no positions can be built from `x_field` and `y_field`.
        """
        positions = self.get_positions()
        if positions is None:
            raise ValueError(
                "x_field and y_field must be provided for inference positions."
            )

        # Only include fields that are not None
        dataset = {"positions": positions}
        if self.sources is not None:
            dataset["sources"] = self.sources
        if self.masks is not None:
            dataset["masks"] = self.masks
        if self.seds is not None:
            dataset["seds"] = self.seds

        return dataset

    @property
    def model_data_adapter(self):
        """Create and return the model DataAdapter using the factory.

        The adapter is built from `data_config`; if the factory returns it in
        the ``SPLIT`` structure state, its datasets are joined.

        Returns
        -------
        DataAdapter
            The cached model data adapter.
        """
        if self._model_data_adapter is None:
            logger.info("Generating the model data adapter...")
            dataset_params = self.data_config

            # Use the factory — it will normalize, convert dicts/dataclasses,
            # and produce LoadedDataset
            adapter = DataAdapterFactory.build(dataset_params)

            # Join data, if not already complete
            if adapter.structure_state == StructureState.SPLIT:
                logger.info("Joining split datasets...")
                adapter.join_data()

            self._model_data_adapter = adapter

        return self._model_data_adapter

    @property
    def inference_data_adapter(self):
        """Create and return a DataAdapter for inference data using the factory.

        Returns
        -------
        DataAdapter
            The cached data adapter built from the user-supplied inputs.
        """
        if self._inference_data_adapter is None:
            logger.info("Generating the inference data adapter...")
            dataset = self._prepare_dataset_for_inference()

            # Use the factory — it will normalize, convert dicts/dataclasses,
            # and produce LoadedDataset
            adapter = DataAdapterFactory.build(dataset)

            self._inference_data_adapter = adapter

        return self._inference_data_adapter

    def _convert_inference_data_to_tensorflow(self):
        """Convert the inference data to TensorFlow if it is still NumPy.

        The conversion follows the dataset schema mode of the configuration
        handler and is skipped when the adapter is not in the ``NUMPY``
        representation state.
        """
        # Go through the lazy property so the adapter is built on demand
        # instead of failing on a None cache.
        adapter = self.inference_data_adapter
        if adapter.representation_state == RepresentationState.NUMPY:
            adapter.convert_to_tensorflow(
                self.simPSF,
                self.n_bins_lambda,
                mode=self.config_handler.schema_mode,
            )

    @property
    def trained_psf_model(self):
        """Get or load the trained PSF model.

        Returns
        -------
        Model
            The loaded trained PSF model.
        """
        if self._trained_psf_model is None:
            self._trained_psf_model = self.load_inference_model()
        return self._trained_psf_model

    def get_positions(self):
        """Combine x_field and y_field into position pairs.

        Returns
        -------
        numpy.ndarray or None
            Array of shape (num_positions, 2) where each row contains
            [x, y] coordinates. Returns None if either x_field or y_field
            is None or empty.

        Raises
        ------
        ValueError
            If x_field and y_field have different lengths.
        """
        if self.x_field is None or self.y_field is None:
            return None

        x_arr = np.asarray(self.x_field)
        y_arr = np.asarray(self.y_field)

        if x_arr.size == 0 or y_arr.size == 0:
            return None

        if x_arr.size != y_arr.size:
            raise ValueError(
                f"x_field and y_field must have the same length. "
                f"Got {x_arr.size} and {y_arr.size}"
            )

        # Flatten arrays to handle any input shape, then stack
        x_flat = x_arr.flatten()
        y_flat = y_arr.flatten()

        return np.column_stack((x_flat, y_flat))

    def load_inference_model(self):
        """Load the trained PSF model based on the inference configuration.

        Returns
        -------
        Model
            The loaded trained PSF model.

        Notes
        -----
        Constructs the weights path pattern based on the trained model path,
        model subdirectory, model name, id name, and cycle number specified
        in the configuration files.
        """
        model_path = self.config_handler.trained_model_path
        model_dir = self.config_handler.model_subdir
        model_name = self.training_config.model_params.model_name
        id_name = self.training_config.id_name

        weights_path_pattern = os.path.join(
            model_path,
            model_dir,
            f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*",
        )

        # Load the trained PSF model
        return load_trained_psf_model(
            self.training_config,
            self.inference_data_adapter.complete_data,
            weights_path_pattern,
        )

    @property
    def n_bins_lambda(self):
        """Get the number of wavelength bins for inference.

        Returns
        -------
        int
            The number of wavelength bins used during inference.
        """
        if self._n_bins_lambda is None:
            self._n_bins_lambda = self.inference_config.model_params.n_bins_lambda
        return self._n_bins_lambda

    @property
    def batch_size(self):
        """Get the batch size for inference.

        Returns
        -------
        int
            The batch size for processing during inference.

        Raises
        ------
        AssertionError
            If the configured batch size is not greater than 0.
        """
        if self._batch_size is None:
            self._batch_size = self.inference_config.batch_size
        # NOTE(review): this check is an ``assert`` and is therefore skipped
        # when Python runs with ``-O``.
        assert self._batch_size > 0, "Batch size must be greater than 0."
        return self._batch_size

    @property
    def cycle(self):
        """Get the cycle number for inference.

        Returns
        -------
        int
            The cycle number used for loading the trained model.
        """
        if self._cycle is None:
            self._cycle = self.inference_config.cycle
        return self._cycle

    @property
    def output_dim(self):
        """Get the output dimension for PSF inference.

        Returns
        -------
        int
            The output dimension (height and width) of the inferred PSFs.
        """
        if self._output_dim is None:
            self._output_dim = self.inference_config.model_params.output_dim
        return self._output_dim

    def run_inference(self):
        """Run PSF inference and return the full PSF array.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape
            (n_samples, output_dim, output_dim).

        Notes
        -----
        Prepares configurations and input data, initializes the inference
        engine, and computes the PSF for all input positions.
        """
        # Prepare the configuration for inference
        self.prepare_configs()

        self.engine = PSFInferenceEngine(
            trained_model=self.trained_psf_model,
            batch_size=self.batch_size,
            output_dim=self.output_dim,
        )

        # Convert inference data to tensorflow type
        self._convert_inference_data_to_tensorflow()

        # Get positions and SEDs
        positions = self.inference_data_adapter.complete_data["positions"]
        sed_data = self.inference_data_adapter.complete_data["seds"]

        return self.engine.compute_psfs(positions, sed_data)

    def _ensure_psf_inference_completed(self):
        """Ensure that PSF inference has been completed.

        Runs inference if there is no engine yet or the engine holds no
        cached PSFs.
        """
        if self.engine is None or self.engine.inferred_psfs is None:
            self.run_inference()

    def get_psfs(self):
        """Get all inferred PSFs.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape
            (n_samples, output_dim, output_dim).

        Notes
        -----
        Ensures automatically that inference has been completed before
        accessing the PSFs.
        """
        self._ensure_psf_inference_completed()
        return self.engine.get_psfs()

    def get_psf(self, index: int = 0) -> np.ndarray:
        """Get the PSF at a specific index.

        Parameters
        ----------
        index : int, optional
            Index of the PSF to retrieve (default is 0).

        Returns
        -------
        numpy.ndarray
            The inferred PSF at the specified index with shape
            (output_dim, output_dim).

        Notes
        -----
        Ensures automatically that inference has been completed before
        accessing the PSF. If only a single PSF was inferred, `index` is
        ignored and that PSF is returned.
        """
        self._ensure_psf_inference_completed()

        inferred_psfs = self.engine.get_psfs()

        # If a single-star batch, ignore index bounds
        if inferred_psfs.shape[0] == 1:
            return inferred_psfs[0]

        # Otherwise, return the PSF at the requested index
        return inferred_psfs[index]

    def clear_cache(self):
        """Clear all cached properties and reset the instance.

        This method resets all lazy-loaded properties, including the config
        handler, PSF simulator, both data adapters, trained model, and
        inference engine. Useful for freeing memory or forcing a fresh
        initialization.

        Notes
        -----
        After calling this method, accessing any property will trigger
        re-initialization.
        """
        self._config_handler = None
        self._simPSF = None
        # Reset both adapters created in ``__init__``; leaving them cached
        # would keep the old dataset alive after a "clear".
        self._model_data_adapter = None
        self._inference_data_adapter = None
        self._trained_psf_model = None
        self._n_bins_lambda = None
        self._batch_size = None
        self._cycle = None
        self._output_dim = None
        self.engine = None
class PSFInferenceEngine:
    """Engine to perform PSF inference using a trained model.

    This class handles the batch-wise computation of PSFs using a trained
    PSF model. It manages the batching of input positions and SEDs, and
    caches the inferred PSFs for later access.

    Parameters
    ----------
    trained_model : Model
        The trained PSF model to use for inference.
    batch_size : int
        The batch size for processing during inference.
    output_dim : int
        The output dimension (height and width) of the inferred PSFs.

    Attributes
    ----------
    trained_model : Model
        The trained PSF model used for inference.
    batch_size : int
        The batch size for processing during inference.
    output_dim : int
        The output dimension (height and width) of the inferred PSFs.

    Examples
    --------
    .. code-block:: python

        engine = PSFInferenceEngine(model, batch_size=32, output_dim=64)
        psfs = engine.compute_psfs(positions, seds)
        single_psf = engine.get_psf(0)
    """

    def __init__(self, trained_model, batch_size: int, output_dim: int):
        self.trained_model = trained_model
        self.batch_size = batch_size
        self.output_dim = output_dim
        self._inferred_psfs = None

    @property
    def inferred_psfs(self) -> Optional[np.ndarray]:
        """Access the cached inferred PSFs, if available.

        Returns
        -------
        numpy.ndarray or None
            The cached inferred PSFs, or None if not yet computed.
        """
        return self._inferred_psfs

    def compute_psfs(self, positions: "tf.Tensor", sed_data: "tf.Tensor") -> np.ndarray:
        """Compute and cache PSFs for the input source parameters.

        Parameters
        ----------
        positions : tf.Tensor
            Tensor of shape (n_samples, 2) containing the (x, y) positions
        sed_data : tf.Tensor
            Tensor of shape (n_samples, n_bins, 2) containing the SEDs

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape
            (n_samples, output_dim, output_dim).

        Raises
        ------
        ValueError
            If `batch_size` is None or not greater than 0.

        Notes
        -----
        PSFs are computed in batches according to the specified batch_size.
        Results are cached internally for subsequent access via get_psfs()
        or get_psf(). The cache is only updated once every batch has been
        computed, so a failure part-way through leaves it unchanged.
        """
        # A non-positive batch size would never advance the counter below.
        if self.batch_size is None or self.batch_size <= 0:
            raise ValueError(
                f"batch_size must be greater than 0, got {self.batch_size!r}."
            )

        n_samples = positions.shape[0]
        psfs = np.zeros(
            (n_samples, self.output_dim, self.output_dim), dtype=np.float32
        )

        counter = 0
        while counter < n_samples:
            # Calculate the batch end element
            end_sample = min(counter + self.batch_size, n_samples)

            # Define the batch inputs
            batch_pos = positions[counter:end_sample, :]
            batch_seds = sed_data[counter:end_sample, :, :]
            batch_inputs = [batch_pos, batch_seds]

            # Generate PSFs for the current batch
            batch_psfs = self.trained_model(batch_inputs, training=False)
            psfs[counter:end_sample, :, :] = batch_psfs.numpy()

            # Update the counter
            counter = end_sample

        # Publish the result only after all batches succeeded.
        self._inferred_psfs = psfs
        return self._inferred_psfs

    def get_psfs(self) -> np.ndarray:
        """Get all the generated PSFs.

        Returns
        -------
        numpy.ndarray
            Array of inferred PSFs with shape
            (n_samples, output_dim, output_dim).

        Raises
        ------
        ValueError
            If PSFs have not yet been computed.
        """
        if self._inferred_psfs is None:
            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
        return self._inferred_psfs

    def get_psf(self, index: int) -> np.ndarray:
        """Get the PSF at a specific index.

        Parameters
        ----------
        index : int
            Index of the PSF to retrieve from the cached array.

        Returns
        -------
        numpy.ndarray
            The inferred PSF at the specified index with shape
            (output_dim, output_dim).

        Raises
        ------
        ValueError
            If PSFs have not yet been computed.
        """
        if self._inferred_psfs is None:
            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
        return self._inferred_psfs[index]

    def clear_cache(self):
        """Clear cached inferred PSFs.

        Resets the internal PSF cache to free memory. After calling this
        method, compute_psfs() must be called again before accessing PSFs.
        """
        self._inferred_psfs = None