Class TensorFlowInferenceEngine
- java.lang.Object
-
- com.codedstream.otterstream.inference.engine.LocalInferenceEngine<org.tensorflow.SavedModelBundle>
-
- com.codedstream.otterstream.tensorflow.TensorFlowInferenceEngine
-
- All Implemented Interfaces:
InferenceEngine<org.tensorflow.SavedModelBundle>
public class TensorFlowInferenceEngine extends LocalInferenceEngine<org.tensorflow.SavedModelBundle>
TensorFlow SavedModel inference engine using the TensorFlow Java API. This engine provides inference capabilities for TensorFlow models saved in the SavedModel format. It leverages the official TensorFlow Java API to load and execute models, with support for both CPU and GPU execution (when TensorFlow is built with GPU support).
Supported TensorFlow Features:
- SavedModel Format: TensorFlow's standard serialization format
- Signature Parsing: Automatic extraction of input/output signatures
- Tensor Types: Float32 and Int32 tensors with multi-dimensional support
- Batch Inference: Native support through tensor shape manipulation
- GPU Acceleration: Automatic when the GPU-enabled TensorFlow Java build is used
Model Loading:
ModelConfig config = ModelConfig.builder()
    .modelPath("/path/to/saved_model")  // Directory containing saved_model.pb
    .modelId("tensorflow-model")
    .modelVersion("v1")
    .build();

TensorFlowInferenceEngine engine = new TensorFlowInferenceEngine();
engine.initialize(config);

// Get metadata including input/output names
ModelMetadata metadata = engine.getMetadata();
List<String> inputNames = engine.getCachedInputNames();

Inference Example:
Map<String, Object> inputs = new HashMap<>();
inputs.put("input_1", new float[]{0.1f, 0.2f, 0.3f});
inputs.put("input_2", new int[]{1, 2, 3});

InferenceResult result = engine.infer(inputs);
float[] predictions = (float[]) result.getOutput("predictions");

Tensor Creation:
Automatically creates appropriate TensorFlow tensors from Java types:
Java Type    TensorFlow Type    Shape
float[]      TFloat32           [1, array_length]
float[][]    TFloat32           [rows, cols]
int[]        TInt32             [1, array_length]

Signature Discovery:
The engine automatically discovers model signatures:
- First tries "serving_default" signature
- Falls back to first available signature
- Extracts input/output tensor names and shapes
- Caches names for performance
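The fallback order above can be sketched independently of the TensorFlow API. This is a minimal illustration; `SignaturePicker` and its `pickSignature` helper are hypothetical stand-ins for the engine's internal logic, not actual library methods:

```java
import java.util.Map;

public class SignaturePicker {
    // Preferred signature key used by TensorFlow Serving.
    static final String SERVING_DEFAULT = "serving_default";

    /**
     * Picks a signature name: "serving_default" if present,
     * otherwise the first available signature, otherwise null.
     */
    public static String pickSignature(Map<String, ?> signatures) {
        if (signatures.containsKey(SERVING_DEFAULT)) {
            return SERVING_DEFAULT;
        }
        return signatures.isEmpty() ? null : signatures.keySet().iterator().next();
    }
}
```

In the real engine the map keys would come from the SavedModel's signature definitions; here any insertion-ordered map demonstrates the selection order.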
Capabilities:
Feature            Supported    Notes
Batch Inference    Yes          Through tensor shape manipulation
Native Batching    Yes          TensorFlow native batch support
Max Batch Size     128          Configurable based on memory
GPU Support        Yes          When the TensorFlow Java GPU version is used
Multi-threading    Yes          Concurrent inference via a separate Session.Runner per thread

Dependencies:
Requires the TensorFlow Java API:
- org.tensorflow:tensorflow-core-platform (runtime)
- org.tensorflow:tensorflow-core-api (runtime)
- For GPU: org.tensorflow:libtensorflow with GPU support
Performance Features:
- Signature Caching: Input/output names cached for performance
- Tensor Reuse: Automatic tensor cleanup to prevent memory leaks
- Direct Memory Access: Uses TensorFlow's native memory management
- Session Pooling: SavedModelBundle manages session resources
Thread Safety:
Session.Runner is not thread-safe, but SavedModelBundle can be used from multiple threads by creating separate runners. Consider:
- Creating separate runners per thread
- Using Session with synchronization
- Implementing connection pooling for high-throughput scenarios
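The pooling option can be sketched with a plain blocking queue. This is a generic, illustrative pool, assuming the pooled resource (e.g. an engine or runner holder) is handed out exclusively to one thread at a time; `RunnerPool` is not part of the library:

```java
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

/** Minimal sketch of a fixed-size pool for non-thread-safe resources. */
public class RunnerPool<T> {
    private final BlockingQueue<T> pool;

    public RunnerPool(List<T> resources) {
        // Fair queue so waiting threads are served in arrival order.
        this.pool = new ArrayBlockingQueue<>(resources.size(), true, resources);
    }

    /** Borrows a resource, blocking until one is free. */
    public T acquire() {
        try {
            return pool.take();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for a pooled resource", e);
        }
    }

    /** Returns a resource to the pool after use. */
    public void release(T resource) {
        pool.offer(resource);
    }
}
```

A caller would acquire before inference and release in a finally block, so at most pool-size inferences run concurrently without sharing a runner.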
Resource Management:
Always call close() to release native TensorFlow resources. TensorFlow uses native memory that must be explicitly released:

try (TensorFlowInferenceEngine engine = new TensorFlowInferenceEngine()) {
    engine.initialize(config);
    InferenceResult result = engine.infer(inputs);
}

- Since:
- 1.0.0
- Author:
- Nestor Martourez, Sr Software and Data Streaming Engineer @ CodedStreams
- See Also:
LocalInferenceEngine, SavedModelBundle, TensorFlow Java API
-
-
Nested Class Summary
-
Nested classes/interfaces inherited from interface com.codedstream.otterstream.inference.engine.InferenceEngine
InferenceEngine.EngineCapabilities
-
-
Field Summary
-
Fields inherited from class com.codedstream.otterstream.inference.engine.LocalInferenceEngine
initialized, loadedModel, modelConfig, modelLoader
-
-
Constructor Summary
Constructors
Constructor                      Description
TensorFlowInferenceEngine()
-
Method Summary
All Methods | Instance Methods | Concrete Methods

Modifier and Type                     Method                                         Description
void                                  close()                                        Closes the TensorFlow engine and releases native resources.
List<String>                          getCachedInputNames()                          Gets the cached input tensor names.
List<String>                          getCachedOutputNames()                         Gets the cached output tensor names.
InferenceEngine.EngineCapabilities    getCapabilities()                              Gets the engine's capabilities for TensorFlow inference.
ModelMetadata                         getMetadata()                                  Gets metadata about the loaded TensorFlow model.
InferenceResult                       infer(Map<String,Object> inputs)               Performs single inference using TensorFlow SavedModel.
InferenceResult                       inferBatch(Map<String,Object>[] batchInputs)   Performs batch inference (simplified implementation).
void                                  initialize(ModelConfig config)                 Initializes the TensorFlow inference engine by loading a SavedModel.
-
Methods inherited from class com.codedstream.otterstream.inference.engine.LocalInferenceEngine
getModelConfig, isReady, loadModelDirectly
-
-
-
-
Method Detail
-
initialize
public void initialize(ModelConfig config) throws InferenceException
Initializes the TensorFlow inference engine by loading a SavedModel. The initialization process:
- Loads the SavedModel from a directory using SavedModelBundle.load(java.lang.String, java.lang.String...)
- Parses the model signature to extract input/output tensor information
- Caches input and output names for performance
- Validates model readiness for inference
SavedModel Structure:
saved_model_directory/
├── saved_model.pb    # Model graph and signatures
├── variables/        # Model weights
└── assets/           # Additional files (optional)
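A quick structural check before attempting to load can be sketched with java.nio.file. `looksLikeSavedModel` is an illustrative helper for pre-validation, not part of the engine's API:

```java
import java.nio.file.Files;
import java.nio.file.Path;

public class SavedModelCheck {
    /**
     * Returns true if the directory has the minimal SavedModel layout:
     * a saved_model.pb file and a variables/ subdirectory.
     * The assets/ directory is optional and not checked.
     */
    public static boolean looksLikeSavedModel(Path dir) {
        return Files.isDirectory(dir)
                && Files.isRegularFile(dir.resolve("saved_model.pb"))
                && Files.isDirectory(dir.resolve("variables"));
    }
}
```

A check like this lets initialization fail fast with a clear InferenceException message instead of surfacing a lower-level native loading error.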
Signature Discovery:
The engine looks for signatures in this order:
- "serving_default" signature (standard for serving)
- First available signature in the model
- Fallback to common output names if no signature found
- Specified by:
  initialize in interface InferenceEngine<org.tensorflow.SavedModelBundle>
- Overrides:
  initialize in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
- Parameters:
  config - model configuration containing the SavedModel directory path
- Throws:
  InferenceException - if model loading fails or the SavedModel is invalid
-
infer
public InferenceResult infer(Map<String,Object> inputs) throws InferenceException
Performs single inference using a TensorFlow SavedModel. The inference process:
- Creates a Session.Runner for inference execution
- Converts input values to TensorFlow tensors
- Feeds tensors to runner using cached input names
- Fetches output tensors using cached output names
- Executes graph and extracts results
- Cleans up tensors to prevent memory leaks
Tensor Management:
All created tensors are automatically closed using try-with-resources pattern. Output tensors from the runner are also explicitly closed in the finally block to prevent native memory leaks.
Error Handling:
- Invalid tensor types throw IllegalArgumentException
- Missing input names result in runtime errors
- TensorFlow runtime errors throw InferenceException
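The type check described above can be sketched as a small validator. `InputValidator` and `requireSupportedTypes` are illustrative names, not the engine's actual internals, assuming the supported value types are float[], float[][], and int[] as documented:

```java
import java.util.Map;

public class InputValidator {
    /** Throws IllegalArgumentException for any value that is not float[], float[][], or int[]. */
    public static void requireSupportedTypes(Map<String, Object> inputs) {
        for (Map.Entry<String, Object> e : inputs.entrySet()) {
            Object v = e.getValue();
            if (!(v instanceof float[]) && !(v instanceof float[][]) && !(v instanceof int[])) {
                throw new IllegalArgumentException(
                        "Unsupported tensor type for input '" + e.getKey() + "': "
                                + (v == null ? "null" : v.getClass().getName()));
            }
        }
    }
}
```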
- Specified by:
  infer in interface InferenceEngine<org.tensorflow.SavedModelBundle>
- Specified by:
  infer in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
- Parameters:
  inputs - map of input tensor names to values (float[] or int[])
- Returns:
  inference result containing outputs and timing information
- Throws:
  InferenceException - if inference fails or tensor creation fails
-
inferBatch
public InferenceResult inferBatch(Map<String,Object>[] batchInputs) throws InferenceException
Performs batch inference (simplified implementation). Note: the current implementation processes only the first input in the batch. For proper batch inference, extend this method to:
- Create batch tensors by stacking individual inputs
- Modify tensor shapes to include batch dimension
- Execute single inference with batch tensor
- Split batch output into individual results
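The stacking and splitting steps above can be sketched in plain Java for float[] inputs. `BatchUtil` is an illustrative helper, not part of the current implementation; the actual extension would build a single TFloat32 tensor with a leading batch dimension from the stacked array:

```java
import java.util.ArrayList;
import java.util.List;

public class BatchUtil {
    /** Stacks individual float[] inputs into a [batch, length] tensor-shaped array. */
    public static float[][] stack(List<float[]> inputs) {
        float[][] batch = new float[inputs.size()][];
        for (int i = 0; i < inputs.size(); i++) {
            batch[i] = inputs.get(i).clone();  // copy so callers can reuse their arrays
        }
        return batch;
    }

    /** Splits a [batch, length] output back into per-request float[] results. */
    public static List<float[]> split(float[][] batchOutput) {
        List<float[]> results = new ArrayList<>(batchOutput.length);
        for (float[] row : batchOutput) {
            results.add(row.clone());
        }
        return results;
    }
}
```

With helpers like these, a batch call becomes: stack the per-request inputs, run one inference on the batch tensor, then split the batch output into one result per request.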
- Specified by:
  inferBatch in interface InferenceEngine<org.tensorflow.SavedModelBundle>
- Specified by:
  inferBatch in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
- Parameters:
  batchInputs - array of input maps for batch processing
- Returns:
  inference result for the first input (placeholder implementation)
- Throws:
  InferenceException - if batch processing fails
-
getCapabilities
public InferenceEngine.EngineCapabilities getCapabilities()
Gets the engine's capabilities for TensorFlow inference. TensorFlow engine capabilities:
- Batch Inference: Supported through tensor batching
- Native Batching: Yes, TensorFlow native batch support
- Max Batch Size: 128 (conservative default for memory safety)
- GPU Support: Yes, when using TensorFlow GPU version
- Specified by:
  getCapabilities in interface InferenceEngine<org.tensorflow.SavedModelBundle>
- Specified by:
  getCapabilities in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
- Returns:
- engine capabilities indicating full TensorFlow support
-
getMetadata
public ModelMetadata getMetadata()
Gets metadata about the loaded TensorFlow model. Extracts metadata from the SavedModel, including:
- Model name and version from configuration
- Input/output schema from cached names
- Model format as ModelFormat.TENSORFLOW_SAVEDMODEL
- Load timestamp for freshness tracking
Schema Format:
Inputs and outputs are stored as maps with indexed keys:
Input Schema:  {"input_0": "input_tensor_name", "input_1": ...}
Output Schema: {"output_0": "output_tensor_name", "output_1": ...}

- Returns:
- comprehensive model metadata
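The indexed schema maps above can be built from the cached names with a few lines. `SchemaUtil.toIndexedSchema` is an illustrative helper, not the engine's actual method:

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SchemaUtil {
    /** Builds {"<prefix>_0": name0, "<prefix>_1": name1, ...} from cached tensor names. */
    public static Map<String, String> toIndexedSchema(String prefix, List<String> names) {
        // LinkedHashMap preserves the original tensor order.
        Map<String, String> schema = new LinkedHashMap<>();
        for (int i = 0; i < names.size(); i++) {
            schema.put(prefix + "_" + i, names.get(i));
        }
        return schema;
    }
}
```

Applied to the cached input and output name lists with prefixes "input" and "output", this yields maps in the documented schema format.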
-
close
public void close() throws InferenceException
Closes the TensorFlow engine and releases native resources. Closes the SavedModelBundle, which releases:
- Graph definition memory
- Session resources
- Variable storage
- Any GPU memory allocated
Always call this method when finished to prevent native memory leaks.
- Specified by:
  close in interface InferenceEngine<org.tensorflow.SavedModelBundle>
- Overrides:
  close in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
- Throws:
  InferenceException - if resource cleanup fails
-
getCachedInputNames
public List<String> getCachedInputNames()
Gets the cached input tensor names.
- Returns:
- list of input tensor names extracted during initialization
-
-