wf_psf.inference.psf_inference

Inference.

This module provides the PSFInference class, which performs inference with trained PSF models: it loads a trained model, runs inference on a dataset of SEDs and positions, and generates polychromatic PSFs.

Authors:

Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobias.liaudat@cea.fr>

Classes

InferenceConfigHandler(inference_config_path)

Handle configuration loading and management for PSF inference.

PSFInference(inference_config_path[, ...])

Perform PSF inference using a pre-trained WaveDiff model.

PSFInferenceEngine(trained_model, ...)

Engine to perform PSF inference using a trained model.

class wf_psf.inference.psf_inference.InferenceConfigHandler(inference_config_path: str)

Bases: object

Handle configuration loading and management for PSF inference.

This class manages the loading of inference, training, and data configuration files required for PSF inference operations.

Parameters:

inference_config_path (str) – Path to the inference configuration YAML file.

inference_config_path

Path to the inference configuration file.

Type:

str

inference_config

Loaded inference configuration.

Type:

RecursiveNamespace or None

training_config

Loaded training configuration.

Type:

RecursiveNamespace or None

data_config

Loaded data configuration.

Type:

RecursiveNamespace or None

trained_model_path

Path to the trained model directory.

Type:

Path

model_subdir

Subdirectory name for model files.

Type:

str

trained_model_config_path

Path to the training configuration file.

Type:

Path

data_config_path

Path to the data configuration file.

Type:

str or None

Methods

load_configs()

Load configuration files based on the inference config.

overwrite_model_params([training_config, ...])

Overwrite training model_params with values from inference_config if available.

set_config_paths()

Extract and set the configuration paths from the inference config.

ids = ('inference_conf',)
load_configs()

Load configuration files based on the inference config.

Loads the inference configuration first, then uses it to determine and load the training and data configurations.

Notes

Updates the following attributes in place:
  • inference_config

  • training_config

  • data_config (if data_config_path is specified)

static overwrite_model_params(training_config=None, inference_config=None)

Overwrite training model_params with values from inference_config if available.

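Parameters:
  • training_config (RecursiveNamespace, optional) – Training configuration whose training.model_params are updated in place.

  • inference_config (RecursiveNamespace, optional) – Inference configuration providing the replacement values.
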
Notes

Updates are applied in-place to training_config.training.model_params.
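The in-place override can be pictured with the minimal sketch below. It uses types.SimpleNamespace as a stand-in for RecursiveNamespace, and it assumes (hypothetically) that the overrides live under inference_config.inference.model_params and that every override key is copied across; the parameter names output_dim and n_bins_lambda are illustrative only, not the actual wf_psf schema.

```python
from types import SimpleNamespace


def overwrite_model_params(training_config=None, inference_config=None):
    """Copy model_params overrides from the inference config into the training config, in place.

    Hypothetical sketch: the override location (inference_config.inference.model_params)
    is an assumption made for illustration.
    """
    if training_config is None or inference_config is None:
        return
    inference_section = getattr(inference_config, "inference", None)
    overrides = getattr(inference_section, "model_params", None)
    if overrides is None:
        return  # nothing to override
    target = training_config.training.model_params
    for key, value in vars(overrides).items():
        setattr(target, key, value)  # mutate the training config directly


training_config = SimpleNamespace(
    training=SimpleNamespace(model_params=SimpleNamespace(output_dim=32, n_bins_lambda=8))
)
inference_config = SimpleNamespace(
    inference=SimpleNamespace(model_params=SimpleNamespace(output_dim=64))
)
overwrite_model_params(training_config, inference_config)
print(training_config.training.model_params.output_dim)  # 64
```

Because the function returns nothing and mutates its argument, callers keep using the same training_config object after the call, which matches the "updated in place" behaviour described above.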

set_config_paths()

Extract and set the configuration paths from the inference config.

Sets the following attributes:
  • trained_model_path

  • model_subdir

  • trained_model_config_path

  • data_config_path
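A sketch of how the four paths might be derived from a parsed inference config is shown below. Every dictionary key, and the assumption that the training config path is relative to the trained model directory, is hypothetical and chosen only to illustrate the idea; the optional data config falls back to None, matching the documented "str or None" type.

```python
from pathlib import Path


def set_config_paths(inference_config: dict) -> dict:
    """Derive the paths that set_config_paths() is documented to set.

    All dictionary keys used here are hypothetical placeholders, not the
    actual wf_psf inference-config schema.
    """
    configs = inference_config["inference"]["configs"]
    trained_model_path = Path(configs["trained_model_path"])
    model_subdir = configs["model_subdir"]
    # Assumption: the training config is stored relative to the trained model directory.
    trained_model_config_path = trained_model_path / configs["trained_model_config_path"]
    # The data config is optional, so it falls back to None when absent.
    data_config_path = configs.get("data_config_path")
    return {
        "trained_model_path": trained_model_path,
        "model_subdir": model_subdir,
        "trained_model_config_path": trained_model_config_path,
        "data_config_path": data_config_path,
    }


example = {
    "inference": {
        "configs": {
            "trained_model_path": "/runs/wavediff",
            "model_subdir": "checkpoint",
            "trained_model_config_path": "config/training_config.yaml",
        }
    }
}
paths = set_config_paths(example)
print(paths["trained_model_config_path"])
print(paths["data_config_path"])
```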

class wf_psf.inference.psf_inference.PSFInference(inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None)

Bases: object

Perform PSF inference using a pre-trained WaveDiff model.

This class handles the setup for PSF inference, including loading configuration files, instantiating the PSF simulator and data handler, and preparing the input data required for inference.

Parameters:
  • inference_config_path (str) – Path to the inference configuration YAML file.

  • x_field (array-like, optional) – x coordinates in SHE convention.

  • y_field (array-like, optional) – y coordinates in SHE convention.

  • seds (array-like, optional) – Spectral energy distributions (SEDs).

  • sources (array-like, optional) – Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]).

  • masks (array-like, optional) – Corresponding masks for the sources (same shape as sources). Defaults to None.

inference_config_path

Path to the inference configuration file.

Type:

str

x_field

x coordinates for PSF positions.

Type:

array-like or None

y_field

y coordinates for PSF positions.

Type:

array-like or None

seds

Spectral energy distributions.

Type:

array-like or None

sources

Source postage stamps.

Type:

array-like or None

masks

Source masks.

Type:

array-like or None

engine

The inference engine instance.

Type:

PSFInferenceEngine or None

Examples

Basic usage with position coordinates and SEDs:

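from wf_psf.inference.psf_inference import PSFInference

# sed_array holds the SEDs for the two positions (prepared beforehand)
psf_inf = PSFInference(
    inference_config_path="config.yaml",
    x_field=[100.5, 200.3],
    y_field=[150.2, 250.8],
    seds=sed_array
)
psf_inf.run_inference()
psf = psf_inf.get_psf(0)

Attributes: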
batch_size

Get the batch size for inference.

config_handler

Get or create the configuration handler.

cycle

Get the cycle number for inference.

data_config

Get the data configuration.

data_handler

Get or create the data handler.

inference_config

Get the inference configuration.

n_bins_lambda

Get the number of wavelength bins for inference.

output_dim

Get the output dimension for PSF inference.

simPSF

Get or create the PSF simulator.

trained_psf_model

Get or load the trained PSF model.

training_config

Get the training configuration.

Methods

clear_cache()

Clear all cached properties and reset the instance.

get_positions()

Combine x_field and y_field into position pairs.

get_psf([index])

Get the PSF at a specific index.

get_psfs()

Get all inferred PSFs.

load_inference_model()

Load the trained PSF model based on the inference configuration.

prepare_configs()

Prepare the configuration for inference.

run_inference()

Run PSF inference and return the full PSF array.

property batch_size

Get the batch size for inference.

Returns:

The batch size for processing during inference.

Return type:

int

clear_cache()

Clear all cached properties and reset the instance.

This method resets all lazy-loaded properties, including the config handler, PSF simulator, data handler, trained model, and inference engine. Useful for freeing memory or forcing a fresh initialization.

Notes

After calling this method, accessing any property will trigger re-initialization.
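The lazy-loading and reset behaviour described above follows a common pattern, sketched below with a hypothetical stand-in class. The private attribute name, the load counter, and the placeholder dictionary are illustrative assumptions, not the PSFInference internals.

```python
class LazyInference:
    """Hypothetical stand-in illustrating lazy properties plus clear_cache()."""

    def __init__(self):
        self._config_handler = None  # cached object, built on first access
        self.load_count = 0  # counts how often the expensive initialization runs

    @property
    def config_handler(self):
        if self._config_handler is None:
            self.load_count += 1
            self._config_handler = {"configs_loaded": True}  # placeholder for real loading
        return self._config_handler

    def clear_cache(self):
        # Reset the cached object; the next property access rebuilds it.
        self._config_handler = None


inf = LazyInference()
inf.config_handler
inf.config_handler  # served from the cache, no second load
print(inf.load_count)  # 1
inf.clear_cache()
inf.config_handler  # re-initialized after the reset
print(inf.load_count)  # 2
```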

property config_handler

Get or create the configuration handler.

Returns:

The configuration handler instance with loaded configs.

Return type:

InferenceConfigHandler

property cycle

Get the cycle number for inference.

Returns:

The cycle number used for loading the trained model.

Return type:

int

property data_config

Get the data configuration.

Returns:

The data configuration object, or None if not available.

Return type:

RecursiveNamespace or None

property data_handler

Get or create the data handler.

Returns:

The data handler instance configured for inference.

Return type:

DataHandler

get_positions()

Combine x_field and y_field into position pairs.

Returns:

Array of shape (num_positions, 2) where each row contains [x, y] coordinates. Returns None if either x_field or y_field is None.

Return type:

numpy.ndarray or None

Raises:

ValueError – If x_field and y_field have different lengths.
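The documented contract of get_positions() (pair the coordinates into rows, return None when either input is missing, raise ValueError on a length mismatch) can be sketched with NumPy as follows. This is an illustrative re-implementation under those stated rules, not the library's own code; flattening the inputs with ravel() is an assumption.

```python
import numpy as np


def get_positions(x_field, y_field):
    """Combine x and y coordinates into an array of shape (num_positions, 2)."""
    if x_field is None or y_field is None:
        return None
    # ravel() also turns a single scalar coordinate into a length-1 array.
    x = np.asarray(x_field, dtype=float).ravel()
    y = np.asarray(y_field, dtype=float).ravel()
    if x.size != y.size:
        raise ValueError(
            f"x_field and y_field must have the same length (got {x.size} and {y.size})."
        )
    # Each row is one [x, y] pair.
    return np.column_stack((x, y))


positions = get_positions([100.5, 200.3], [150.2, 250.8])
print(positions.shape)  # (2, 2)
```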

get_psf(index: int = 0) → ndarray

Get the PSF at a specific index.

Parameters:

index (int, optional) – Index of the PSF to retrieve (default is 0).

Returns:

The inferred PSF at the specified index with shape (output_dim, output_dim).

Return type:

numpy.ndarray

Notes

Automatically ensures that inference has been completed before the PSF is accessed. If only a single star was passed during instantiation, the index defaults to 0 and bounds checking is relaxed.

get_psfs()

Get all inferred PSFs.

Returns:

Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

Return type:

numpy.ndarray

Notes

Automatically ensures that inference has been completed before the PSFs are accessed.

property inference_config

Get the inference configuration.

Returns:

The inference configuration object.

Return type:

RecursiveNamespace

load_inference_model()

Load the trained PSF model based on the inference configuration.

Returns:

The loaded trained PSF model.

Return type:

Model

Notes

Constructs the weights path pattern based on the trained model path, model subdirectory, model name, id name, and cycle number specified in the configuration files.
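As a rough illustration of building and resolving a weights path pattern from these five pieces, the sketch below uses the standard glob module. The file-naming scheme (model subdirectory prefix, then model name, id name and "_cycle<N>") is a hypothetical example, not the actual wf_psf checkpoint layout, and the demo runs against a throwaway temporary directory.

```python
import glob
import tempfile
from pathlib import Path


def weights_path_pattern(trained_model_path, model_subdir, model_name, id_name, cycle):
    """Build a glob pattern for the weights of a given training cycle.

    The naming scheme is a hypothetical example, not the actual wf_psf layout.
    """
    filename = f"{model_subdir}*_{model_name}*{id_name}_cycle{cycle}*"
    return str(Path(trained_model_path) / model_subdir / filename)


def find_weights(pattern):
    """Return the sorted files matching pattern, or raise if nothing matches."""
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No weights files match {pattern!r}")
    return matches


# Demonstrate on a throwaway directory containing one fake checkpoint file.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "checkpoint"
    ckpt_dir.mkdir()
    (ckpt_dir / "checkpoint_callback_poly_run1_cycle2.index").touch()
    pattern = weights_path_pattern(tmp, "checkpoint", "poly", "run1", 2)
    found = [Path(p).name for p in find_weights(pattern)]
print(found)  # ['checkpoint_callback_poly_run1_cycle2.index']
```

Raising FileNotFoundError when nothing matches makes a misconfigured cycle number or id name fail loudly instead of silently loading no weights.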

property n_bins_lambda

Get the number of wavelength bins for inference.

Returns:

The number of wavelength bins used during inference.

Return type:

int

property output_dim

Get the output dimension for PSF inference.

Returns:

The output dimension (height and width) of the inferred PSFs.

Return type:

int

prepare_configs()

Prepare the configuration for inference.

Overwrites training model parameters with inference configuration values.

run_inference()

Run PSF inference and return the full PSF array.

Returns:

Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

Return type:

numpy.ndarray

Notes

Prepares configurations and input data, initializes the inference engine, and computes the PSF for all input positions.

property simPSF

Get or create the PSF simulator.

Returns:

The PSF simulator instance.

Return type:

simPSF

property trained_psf_model

Get or load the trained PSF model.

Returns:

The loaded trained PSF model.

Return type:

Model

property training_config

Get the training configuration.

Returns:

The training configuration object.

Return type:

RecursiveNamespace

class wf_psf.inference.psf_inference.PSFInferenceEngine(trained_model, batch_size: int, output_dim: int)

Bases: object

Engine to perform PSF inference using a trained model.

This class handles the batch-wise computation of PSFs using a trained PSF model. It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access.

Parameters:
  • trained_model (Model) – The trained PSF model to use for inference.

  • batch_size (int) – The batch size for processing during inference.

  • output_dim (int) – The output dimension (height and width) of the inferred PSFs.

trained_model

The trained PSF model used for inference.

Type:

Model

batch_size

The batch size for processing during inference.

Type:

int

output_dim

The output dimension (height and width) of the inferred PSFs.

Type:

int

Examples

>>> from wf_psf.inference.psf_inference import PSFInferenceEngine
>>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64)
>>> psfs = engine.compute_psfs(positions, seds)
>>> single_psf = engine.get_psf(0)

Attributes:
inferred_psfs

Access the cached inferred PSFs, if available.

Methods

clear_cache()

Clear cached inferred PSFs.

compute_psfs(positions, sed_data)

Compute and cache PSFs for the input source parameters.

get_psf(index)

Get the PSF at a specific index.

get_psfs()

Get all the generated PSFs.

clear_cache()

Clear cached inferred PSFs.

Resets the internal PSF cache to free memory. After calling this method, compute_psfs() must be called again before accessing PSFs.
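The cache lifecycle (compute, access, clear, then fail until recomputed) can be sketched with a hypothetical minimal class. Here model_fn is a stand-in callable rather than a trained TensorFlow model, and batching is deliberately left out to keep the focus on the cache.

```python
import numpy as np


class CachedPSFEngine:
    """Hypothetical minimal engine mirroring the documented cache lifecycle."""

    def __init__(self, model_fn):
        self.model_fn = model_fn  # stand-in for the trained model
        self._inferred_psfs = None

    @property
    def inferred_psfs(self):
        # None until compute_psfs() has run (or after clear_cache()).
        return self._inferred_psfs

    def compute_psfs(self, positions, sed_data):
        # Batching is omitted in this sketch; all sources are evaluated at once.
        self._inferred_psfs = np.asarray(self.model_fn(positions, sed_data))
        return self._inferred_psfs

    def get_psf(self, index):
        if self._inferred_psfs is None:
            raise ValueError("PSFs have not been computed yet; call compute_psfs() first.")
        return self._inferred_psfs[index]

    def clear_cache(self):
        # Free the cached array; compute_psfs() must run again before get_psf().
        self._inferred_psfs = None


def toy_model(positions, sed_data):
    """Fake model returning a 4x4 stamp per source."""
    return np.ones((positions.shape[0], 4, 4))


engine = CachedPSFEngine(toy_model)
engine.compute_psfs(np.zeros((3, 2)), np.zeros((3, 8, 2)))
print(engine.get_psf(0).shape)  # (4, 4)
engine.clear_cache()
print(engine.inferred_psfs)  # None
```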

compute_psfs(positions: Tensor, sed_data: Tensor) → ndarray

Compute and cache PSFs for the input source parameters.

Parameters:
  • positions (tf.Tensor) – Tensor of shape (n_samples, 2) containing the (x, y) positions.

  • sed_data (tf.Tensor) – Tensor of shape (n_samples, n_bins, 2) containing the SEDs.

Returns:

Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

Return type:

numpy.ndarray

Notes

PSFs are computed in batches according to the specified batch_size. Results are cached internally for subsequent access via get_psfs() or get_psf().
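Batch-wise evaluation of this kind can be sketched with NumPy as below. In the real engine the inputs are tf.Tensor objects and the trained model produces the stamps; here model_fn is a hypothetical NumPy callable, and the SED array is a zero-filled placeholder that only respects the documented (n_samples, n_bins, 2) shape. The toy model records each batch size and fills every stamp with its source's x coordinate so the assembly order can be checked.

```python
import numpy as np


def compute_psfs_in_batches(model_fn, positions, sed_data, batch_size, output_dim):
    """Evaluate model_fn slice by slice and assemble an (n_samples, output_dim, output_dim) array."""
    n_samples = positions.shape[0]
    if sed_data.shape[0] != n_samples:
        raise ValueError("positions and sed_data must contain the same number of samples.")
    psfs = np.empty((n_samples, output_dim, output_dim), dtype=np.float32)
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)  # the last batch may be smaller
        psfs[start:stop] = model_fn(positions[start:stop], sed_data[start:stop])
    return psfs


batch_sizes_seen = []


def toy_model(positions, sed_data):
    """Fake model: records the batch size and fills each stamp with its x coordinate."""
    batch_sizes_seen.append(positions.shape[0])
    return np.broadcast_to(positions[:, 0, None, None], (positions.shape[0], 4, 4))


positions = np.arange(20, dtype=np.float32).reshape(10, 2)
sed_data = np.zeros((10, 8, 2), dtype=np.float32)
psfs = compute_psfs_in_batches(toy_model, positions, sed_data, batch_size=4, output_dim=4)
print(psfs.shape, batch_sizes_seen)  # (10, 4, 4) [4, 4, 2]
```

Preallocating the output array and writing each slice into it avoids keeping a list of per-batch arrays and concatenating them at the end.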

get_psf(index: int) → ndarray

Get the PSF at a specific index.

Parameters:

index (int) – Index of the PSF to retrieve.

Returns:

The inferred PSF at the specified index with shape (output_dim, output_dim).

Return type:

numpy.ndarray

Raises:

ValueError – If PSFs have not yet been computed.

get_psfs() → ndarray

Get all the generated PSFs.

Returns:

Array of inferred PSFs with shape (n_samples, output_dim, output_dim).

Return type:

numpy.ndarray

property inferred_psfs: ndarray

Access the cached inferred PSFs, if available.

Returns:

The cached inferred PSFs, or None if not yet computed.

Return type:

numpy.ndarray or None