wf_psf.psf_models.tf_modules.tf_utils

TensorFlow Utilities Module.

Provides lightweight utility functions for safely converting and managing data types within TensorFlow-based workflows.

Includes:

- ensure_tensor: ensures inputs are TensorFlow tensors with specified dtype
- find_position_indices: finds the indices of batch positions within observed positions

These tools are designed to support PSF model components, including lazy property evaluation, data input validation, and type normalization.

This module is intended for internal use in model layers and inference components to enforce TensorFlow-compatible inputs.

Authors: Jennifer Pollack <jennifer.pollack@cea.fr>

Functions

ensure_tensor(input_array[, dtype])

Ensure the input is a TensorFlow tensor of the specified dtype.

find_position_indices(obs_pos, batch_positions)

Find indices of batch positions within observed positions using vectorized operations.

wf_psf.psf_models.tf_modules.tf_utils.ensure_tensor(input_array, dtype=tf.float32)

Ensure the input is a TensorFlow tensor of the specified dtype.

Parameters:
  • input_array (array-like, tf.Tensor, or np.ndarray) – The input to convert.

  • dtype (tf.DType, optional) – The desired TensorFlow dtype (default: tf.float32).

Returns:

A TensorFlow tensor with the specified dtype.

Return type:

tf.Tensor
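The conversion described above can be illustrated with a minimal sketch. This is not the wf_psf source: `ensure_tensor_sketch` is a hypothetical stand-in, and it assumes that converting with `tf.convert_to_tensor` and then casting with `tf.cast` is an acceptable way to reach the requested dtype for numeric inputs.

```python
import numpy as np
import tensorflow as tf


def ensure_tensor_sketch(input_array, dtype=tf.float32):
    """Hypothetical stand-in for ensure_tensor; not the wf_psf implementation."""
    # Lists, NumPy arrays, and scalars become tensors; existing tensors pass through.
    tensor = tf.convert_to_tensor(input_array)
    # Cast to the requested dtype (a no-op when the dtype already matches).
    return tf.cast(tensor, dtype)


# Python list of ints -> float32 tensor (the default dtype).
from_list = ensure_tensor_sketch([1, 2, 3])

# float64 NumPy array -> float32 tensor.
from_numpy = ensure_tensor_sketch(np.arange(3, dtype=np.float64))

# Existing float32 tensor -> int32 tensor when another dtype is requested.
as_int = ensure_tensor_sketch(tf.constant([1.0, 2.0]), dtype=tf.int32)
```

Normalizing inputs at the boundary of a layer in this way means downstream TensorFlow operations see a single, predictable dtype regardless of how the caller supplied the data.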

wf_psf.psf_models.tf_modules.tf_utils.find_position_indices(obs_pos, batch_positions)

Find indices of batch positions within observed positions using vectorized operations.

This function locates the indices of multiple query positions within a reference set of observed positions using broadcasting and vectorized operations. Each position in the batch must have an exact match in the observed positions.

Parameters:
  • obs_pos (tf.Tensor) – Reference positions tensor of shape (n_obs, 2), where n_obs is the number of observed positions. Each row contains [x, y] coordinates.

  • batch_positions (tf.Tensor) – Query positions tensor of shape (batch_size, 2), where batch_size is the number of positions to look up. Each row contains [x, y] coordinates.

Returns:

indices – Tensor of shape (batch_size,) containing the indices of each batch position within obs_pos. The dtype is tf.int64.

Return type:

tf.Tensor

Raises:

tf.errors.InvalidArgumentError – If any position in batch_positions is not found in obs_pos.

Notes

Matching uses exact equality, so each query position must be identical to a row of obs_pos; no tolerance is applied. Because the lookup is vectorized, it is more efficient than iterative lookups when resolving multiple positions.
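The broadcasting lookup described above might look roughly like the following sketch. It is an illustration under stated assumptions, not the wf_psf implementation: `find_position_indices_sketch` is a hypothetical name, and the use of `tf.debugging.Assert` to signal a missing position is an assumption chosen because it raises `tf.errors.InvalidArgumentError` in eager mode.

```python
import tensorflow as tf


def find_position_indices_sketch(obs_pos, batch_positions):
    """Hypothetical sketch of a broadcast-based exact position lookup."""
    obs_pos = tf.convert_to_tensor(obs_pos)
    batch_positions = tf.convert_to_tensor(batch_positions)

    # Broadcast (batch_size, 1, 2) against (1, n_obs, 2) -> (batch_size, n_obs, 2).
    coords_equal = tf.equal(batch_positions[:, None, :], obs_pos[None, :, :])

    # A position matches only if both x and y are equal -> (batch_size, n_obs).
    matches = tf.reduce_all(coords_equal, axis=-1)

    # Every query row must match at least one observed row.
    has_match = tf.reduce_any(matches, axis=1)
    tf.debugging.Assert(
        tf.reduce_all(has_match),
        ["Some batch positions were not found in obs_pos:", has_match],
    )

    # Index of the first matching observed row for each query; dtype is int64.
    return tf.argmax(tf.cast(matches, tf.int32), axis=1, output_type=tf.int64)


obs = tf.constant([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])
batch = tf.constant([[3.0, 4.0], [0.0, 0.0]])
indices = find_position_indices_sketch(obs, batch)

# A query that is absent from obs triggers the error path.
try:
    find_position_indices_sketch(obs, tf.constant([[9.0, 9.0]]))
    missing_raised = False
except tf.errors.InvalidArgumentError:
    missing_raised = True
```

The intermediate comparison tensor has shape (batch_size, n_obs, 2), so this approach trades memory for speed. Because the comparison is exact, positions that differ only by floating-point rounding will not match.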