Source code for wf_psf.psf_models.tf_modules.tf_utils

"""TensorFlow Utilities Module.

Provides lightweight utility functions for safely converting and managing data types
within TensorFlow-based workflows.

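Includes:
- `find_position_indices`: locates the indices of query positions within a set of observed positions
- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype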

These tools are designed to support PSF model components, including lazy property evaluation,
data input validation, and type normalization.

This module is intended for internal use in model layers and inference components to enforce
TensorFlow-compatible inputs.

Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""

import tensorflow as tf


@tf.function
def find_position_indices(obs_pos, batch_positions):
    """Find indices of batch positions within observed positions using vectorized operations.

    This function locates the indices of multiple query positions within a
    reference set of observed positions using broadcasting and vectorized
    operations. Each position in the batch must have an exact match in the
    observed positions.

    Parameters
    ----------
    obs_pos : tf.Tensor
        Reference positions tensor of shape (n_obs, 2), where n_obs is the
        number of observed positions. Each row contains [x, y] coordinates.
    batch_positions : tf.Tensor
        Query positions tensor of shape (batch_size, 2), where batch_size is
        the number of positions to look up. Each row contains [x, y]
        coordinates. Expected to share ``obs_pos``'s dtype, since
        ``tf.equal`` does not perform type promotion.

    Returns
    -------
    indices : tf.Tensor
        Tensor of shape (batch_size,) containing, for each batch position,
        the index of its first exact match in ``obs_pos``. The dtype is
        tf.int64.

    Raises
    ------
    tf.errors.InvalidArgumentError
        If any position in batch_positions is not found in obs_pos.

    Notes
    -----
    Uses exact equality matching - positions must match exactly.

    If ``obs_pos`` contains duplicate rows, the smallest matching index is
    returned deterministically. ``tf.argmax`` is deliberately not used,
    because it does not guarantee which index wins on ties.

    Equality is reduced over the last axis only. Rows of any shared
    coordinate length therefore work, not just 2-D positions.

    The broadcast comparison materializes a boolean tensor of shape
    (batch_size, n_obs, n_coords), so memory grows with
    batch_size * n_obs.
    """
    # Expand for broadcasting: (1, n_obs, D) against (batch_size, 1, D).
    obs_expanded = tf.expand_dims(obs_pos, 0)
    pos_expanded = tf.expand_dims(batch_positions, 1)

    # Row-wise exact equality for every (batch item, observed position) pair,
    # giving shape (batch_size, n_obs).
    matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2)

    # Verify all positions were found.
    tf.debugging.assert_equal(
        tf.reduce_all(tf.reduce_any(matches, axis=1)),
        True,
        message="Some positions not found in obs_pos",
    )

    # Deterministic "first match": non-matching columns get the sentinel
    # n_obs (never a valid index), and the row-wise minimum is the smallest
    # matching column index.
    n_obs = tf.shape(obs_pos, out_type=tf.int64)[0]
    column_idx = tf.range(n_obs, dtype=tf.int64)[tf.newaxis, :]
    candidate_idx = tf.where(matches, column_idx, n_obs)
    indices = tf.reduce_min(candidate_idx, axis=1)
    return indices
def ensure_tensor(input_array, dtype=tf.float32):
    """Return ``input_array`` as a TensorFlow tensor with the requested dtype.

    Parameters
    ----------
    input_array : array-like, tf.Tensor, or np.ndarray
        Value to normalize. Tensors are reused or cast. Anything else
        (NumPy arrays, Python scalars or sequences, ...) is handed to
        ``tf.convert_to_tensor``.
    dtype : tf.DType, optional
        Target TensorFlow dtype (default: tf.float32).

    Returns
    -------
    tf.Tensor
        ``input_array`` itself when it is already a tensor of ``dtype``.
        Otherwise, a new tensor of ``dtype``.
    """
    if not tf.is_tensor(input_array):
        # Non-tensor input: convert straight to the target dtype.
        return tf.convert_to_tensor(input_array, dtype=dtype)

    # Existing tensor: cast only when its dtype differs, to avoid a
    # redundant op.
    needs_cast = input_array.dtype != dtype
    return tf.cast(input_array, dtype) if needs_cast else input_array