Class TensorFlowInferenceEngine

  • All Implemented Interfaces:
    InferenceEngine<org.tensorflow.SavedModelBundle>

    public class TensorFlowInferenceEngine
    extends LocalInferenceEngine<org.tensorflow.SavedModelBundle>
    TensorFlow SavedModel inference engine using TensorFlow Java API.

    This engine provides inference capabilities for TensorFlow models saved in the SavedModel format. It leverages the official TensorFlow Java API to load and execute models with support for both CPU and GPU execution (when TensorFlow is built with GPU support).

    Supported TensorFlow Features:

    • SavedModel Format: TensorFlow's standard serialization format
    • Signature Parsing: Automatic extraction of input/output signatures
    • Tensor Types: Float32 and Int32 tensors with multi-dimensional support
    • Batch Inference: Native support through tensor shape manipulation
    • GPU Acceleration: Automatic when TensorFlow Java is built with GPU support

    Model Loading:

    
     ModelConfig config = ModelConfig.builder()
         .modelPath("/path/to/saved_model")  // Directory containing saved_model.pb
         .modelId("tensorflow-model")
         .modelVersion("v1")
         .build();
    
     TensorFlowInferenceEngine engine = new TensorFlowInferenceEngine();
     engine.initialize(config);
    
     // Get metadata including input/output names
     ModelMetadata metadata = engine.getMetadata();
     List<String> inputNames = engine.getCachedInputNames();
     

    Inference Example:

    
     Map<String, Object> inputs = new HashMap<>();
     inputs.put("input_1", new float[]{0.1f, 0.2f, 0.3f});
     inputs.put("input_2", new int[]{1, 2, 3});
    
     InferenceResult result = engine.infer(inputs);
     float[] predictions = (float[]) result.getOutput("predictions");
     

    Tensor Creation:

    Automatically creates appropriate TensorFlow tensors from Java types:

    Java Type     TensorFlow Type   Shape
    float[]       TFloat32          [1, array_length]
    float[][]     TFloat32          [rows, cols]
    int[]         TInt32            [1, array_length]
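
    The shape mapping above can be sketched in plain Java. This is a simplified illustration with a hypothetical helper name; the actual engine builds org.tensorflow tensors such as TFloat32 rather than returning shape arrays:

    ```java
    // Illustrative sketch: derive the tensor shape the engine would use for a
    // given Java input value. TensorShapes/shapeFor are assumed names, not the
    // engine's real API.
    class TensorShapes {
        static long[] shapeFor(Object value) {
            if (value instanceof float[]) {
                return new long[]{1, ((float[]) value).length};    // [1, array_length]
            } else if (value instanceof float[][]) {
                float[][] m = (float[][]) value;
                return new long[]{m.length, m[0].length};          // [rows, cols]
            } else if (value instanceof int[]) {
                return new long[]{1, ((int[]) value).length};      // [1, array_length]
            }
            throw new IllegalArgumentException("Unsupported input type: " + value.getClass());
        }
    }
    ```

    Note that 1-D arrays are given a leading singleton dimension so every tensor carries an explicit batch axis.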

    Signature Discovery:

    The engine automatically discovers model signatures:

    1. First tries "serving_default" signature
    2. Falls back to first available signature
    3. Extracts input/output tensor names and shapes
    4. Caches names for performance
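
    The selection order in steps 1-2 amounts to a simple fallback over the model's signature map. A minimal sketch, assuming a hypothetical SignaturePicker helper (the real engine reads signatures from the SavedModel's metagraph):

    ```java
    import java.util.Map;

    // Illustrative sketch of signature selection: prefer "serving_default",
    // otherwise fall back to the first signature the model exposes.
    class SignaturePicker {
        static String pickSignature(Map<String, ?> signatures) {
            if (signatures.isEmpty()) {
                throw new IllegalStateException("SavedModel exposes no signatures");
            }
            return signatures.containsKey("serving_default")
                    ? "serving_default"
                    : signatures.keySet().iterator().next();
        }
    }
    ```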

    Capabilities:

    Feature            Supported   Notes
    Batch Inference    Yes         Through tensor shape manipulation
    Native Batching    Yes         TensorFlow native batch support
    Max Batch Size     128         Configurable based on memory
    GPU Support        Yes         When the TensorFlow Java GPU build is used
    Multi-threading    Yes         Via separate Session.Runner instances per thread

    Dependencies:

     Requires TensorFlow Java API:
     - org.tensorflow:tensorflow-core-platform (runtime)
     - org.tensorflow:tensorflow-core-api (runtime)
     - For GPU: org.tensorflow:libtensorflow with GPU support
     

    Performance Features:

    • Signature Caching: Input/output names cached for performance
    • Tensor Reuse: Automatic tensor cleanup to prevent memory leaks
    • Direct Memory Access: Uses TensorFlow's native memory management
    • Session Pooling: SavedModelBundle manages session resources

    Thread Safety:

    Session.Runner is not thread-safe, but SavedModelBundle can be used from multiple threads by creating separate runners. Consider:

    • Creating separate runners per thread
    • Using Session with synchronization
    • Implementing connection pooling for high-throughput scenarios
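
    The "separate runners per thread" option can be sketched with a ThreadLocal factory. Here a generic Supplier stands in for obtaining a runner from the bundle's session (names are illustrative; in practice a fresh Session.Runner is often created per inference call, and this pattern generalizes to any non-thread-safe per-thread resource):

    ```java
    import java.util.function.Supplier;

    // Illustrative pattern: give each thread its own instance of a
    // non-thread-safe resource (e.g. a runner) so it is never shared.
    class PerThreadResource<R> {
        private final ThreadLocal<R> resources;

        PerThreadResource(Supplier<R> factory) {
            this.resources = ThreadLocal.withInitial(factory);
        }

        R get() {
            return resources.get();  // each thread lazily creates and reuses its own
        }
    }
    ```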

    Resource Management:

    Always call close() to release native TensorFlow resources. TensorFlow uses native memory that must be explicitly released:

    
     try (TensorFlowInferenceEngine engine = new TensorFlowInferenceEngine()) {
         engine.initialize(config);
         InferenceResult result = engine.infer(inputs);
     }
     
    Since:
    1.0.0
    Author:
    Nestor Martourez, Sr Software and Data Streaming Engineer @ CodedStreams
    See Also:
    LocalInferenceEngine, SavedModelBundle, TensorFlow Java API
    • Constructor Detail

      • TensorFlowInferenceEngine

        public TensorFlowInferenceEngine()
    • Method Detail

      • initialize

        public void initialize​(ModelConfig config)
                        throws InferenceException
        Initializes the TensorFlow inference engine by loading a SavedModel.

        The initialization process:

        1. Loads SavedModel from directory using SavedModelBundle.load(java.lang.String, java.lang.String...)
        2. Parses model signature to extract input/output tensor information
        3. Caches input and output names for performance
        4. Validates model readiness for inference

        SavedModel Structure:

         saved_model_directory/
           ├── saved_model.pb      # Model graph and signatures
           ├── variables/          # Model weights
           └── assets/             # Additional files (optional)
         

        Signature Discovery:

        The engine looks for signatures in this order:

        1. "serving_default" signature (standard for serving)
        2. First available signature in the model
        3. Fallback to common output names if no signature found
        Specified by:
        initialize in interface InferenceEngine<org.tensorflow.SavedModelBundle>
        Overrides:
        initialize in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
        Parameters:
        config - model configuration containing SavedModel directory path
        Throws:
        InferenceException - if model loading fails or SavedModel is invalid
      • infer

        public InferenceResult infer​(Map<String,​Object> inputs)
                              throws InferenceException
        Performs single inference using TensorFlow SavedModel.

        The inference process:

        1. Creates Session.Runner for inference execution
        2. Converts input values to TensorFlow tensors
        3. Feeds tensors to runner using cached input names
        4. Fetches output tensors using cached output names
        5. Executes graph and extracts results
        6. Cleans up tensors to prevent memory leaks

        Tensor Management:

        All created tensors are automatically closed using try-with-resources pattern. Output tensors from the runner are also explicitly closed in the finally block to prevent native memory leaks.
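
        The cleanup discipline described above can be illustrated with a minimal AutoCloseable stand-in for a native tensor (the real engine closes org.tensorflow Tensor instances the same way; FakeTensor is purely illustrative):

        ```java
        // Minimal stand-in for a native tensor: try-with-resources guarantees
        // close() runs even if inference throws, mirroring how the engine
        // releases TensorFlow tensors backed by native memory.
        class FakeTensor implements AutoCloseable {
            boolean closed = false;

            @Override
            public void close() {
                closed = true;  // a real tensor releases native memory here
            }

            static boolean runInference() {
                FakeTensor t = new FakeTensor();
                try (t) {
                    // feed/fetch/run would happen here
                }
                return t.closed;  // true: tensor was closed on block exit
            }
        }
        ```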

        Error Handling:

        TensorFlow execution errors and tensor-creation failures are reported as InferenceException, as described under Throws below.

        Specified by:
        infer in interface InferenceEngine<org.tensorflow.SavedModelBundle>
        Specified by:
        infer in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
        Parameters:
        inputs - map of input tensor names to values (float[] or int[])
        Returns:
        inference result containing outputs and timing information
        Throws:
        InferenceException - if inference fails or tensor creation fails
      • inferBatch

        public InferenceResult inferBatch​(Map<String,​Object>[] batchInputs)
                                   throws InferenceException
        Performs batch inference (simplified implementation).

        Note: Current implementation processes only the first input in the batch. For proper batch inference, extend this method to:

        1. Create batch tensors by stacking individual inputs
        2. Modify tensor shapes to include batch dimension
        3. Execute single inference with batch tensor
        4. Split batch output into individual results
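
        Step 1 of the suggested extension, stacking equal-length float[] inputs into one array with a leading batch dimension, might look like this pure-Java sketch (BatchStacker is a hypothetical helper; the resulting float[][] would then back a single TFloat32 of shape [batch, features]):

        ```java
        import java.util.Arrays;

        // Illustrative sketch: stack N equal-length float[] inputs into a
        // single [batch, features] array a batch tensor could be built from.
        class BatchStacker {
            static float[][] stack(float[][] inputs) {
                int features = inputs[0].length;
                float[][] batch = new float[inputs.length][];
                for (int i = 0; i < inputs.length; i++) {
                    if (inputs[i].length != features) {
                        throw new IllegalArgumentException("All inputs must share one shape");
                    }
                    batch[i] = Arrays.copyOf(inputs[i], features);
                }
                return batch;
            }
        }
        ```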
        Specified by:
        inferBatch in interface InferenceEngine<org.tensorflow.SavedModelBundle>
        Specified by:
        inferBatch in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
        Parameters:
        batchInputs - array of input maps for batch processing
        Returns:
        inference result for first input (placeholder implementation)
        Throws:
        InferenceException - if batch processing fails
      • getCapabilities

        public InferenceEngine.EngineCapabilities getCapabilities()
        Gets the engine's capabilities for TensorFlow inference.

        TensorFlow engine capabilities:

        • Batch Inference: Supported through tensor batching
        • Native Batching: Yes, TensorFlow native batch support
        • Max Batch Size: 128 (conservative default for memory safety)
        • GPU Support: Yes, when using TensorFlow GPU version
        Specified by:
        getCapabilities in interface InferenceEngine<org.tensorflow.SavedModelBundle>
        Specified by:
        getCapabilities in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
        Returns:
        engine capabilities indicating full TensorFlow support
      • getMetadata

        public ModelMetadata getMetadata()
        Gets metadata about the loaded TensorFlow model.

        Extracts metadata from the SavedModel including:

        • Model name and version from configuration
        • Input/output schema from cached names
        • Model format as ModelFormat.TENSORFLOW_SAVEDMODEL
        • Load timestamp for freshness tracking

        Schema Format:

        Inputs and outputs are stored as maps with indexed keys:

         Input Schema:  {"input_0": "input_tensor_name", "input_1": ...}
         Output Schema: {"output_0": "output_tensor_name", "output_1": ...}
         
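        Building the indexed schema from the cached name lists is a straightforward mapping; a sketch (SchemaBuilder and indexedSchema are illustrative names, not the engine's API):

        ```java
        import java.util.LinkedHashMap;
        import java.util.List;
        import java.util.Map;

        // Illustrative helper: turn a cached tensor-name list into the indexed
        // schema map shown above, e.g. ["dense_input"] -> {"input_0": "dense_input"}.
        class SchemaBuilder {
            static Map<String, String> indexedSchema(String prefix, List<String> names) {
                Map<String, String> schema = new LinkedHashMap<>();
                for (int i = 0; i < names.size(); i++) {
                    schema.put(prefix + "_" + i, names.get(i));
                }
                return schema;
            }
        }
        ```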
        Returns:
        comprehensive model metadata
      • close

        public void close()
                   throws InferenceException
        Closes the TensorFlow engine and releases native resources.

        Closes the SavedModelBundle which releases:

        • Graph definition memory
        • Session resources
        • Variable storage
        • Any GPU memory allocated

        Always call this method when finished to prevent native memory leaks.

        Specified by:
        close in interface InferenceEngine<org.tensorflow.SavedModelBundle>
        Overrides:
        close in class LocalInferenceEngine<org.tensorflow.SavedModelBundle>
        Throws:
        InferenceException - if resource cleanup fails
      • getCachedInputNames

        public List<String> getCachedInputNames()
        Gets the cached input tensor names.
        Returns:
        list of input tensor names extracted during initialization
      • getCachedOutputNames

        public List<String> getCachedOutputNames()
        Gets the cached output tensor names.
        Returns:
        list of output tensor names extracted during initialization