Class TorchScriptInferenceEngine
-
- All Implemented Interfaces:
InferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
public class TorchScriptInferenceEngine extends LocalInferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
TorchScript (PyTorch) implementation of LocalInferenceEngine using Deep Java Library (DJL). This engine provides inference capabilities for PyTorch models saved in TorchScript format. It leverages DJL's PyTorch engine to load and execute models with automatic GPU acceleration when available. The engine handles PyTorch's dynamic graph execution and tensor operations.
Supported PyTorch Features:
- TorchScript Models: PyTorch models exported via torch.jit.trace or torch.jit.script
- Data Types: Float and integer tensors with automatic dimension handling
- Batch Dimension: Automatic addition of batch dimension via expandDims(0)
- GPU Acceleration: Automatic GPU detection and execution when available
Model Loading:
ModelConfig config = ModelConfig.builder()
    .modelPath("model.pt")      // TorchScript model file
    .modelId("pytorch-model")
    .build();
TorchScriptInferenceEngine engine = new TorchScriptInferenceEngine();
engine.initialize(config);

Inference Example:

Map<String, Object> inputs = new HashMap<>();
inputs.put("input1", new float[]{0.1f, 0.2f, 0.3f});
inputs.put("input2", new int[]{1, 2, 3});
InferenceResult result = engine.infer(inputs);
float[] predictions = (float[]) result.getOutput("output_0");

Input Processing:
The TorchScriptInferenceEngine.MapTranslator automatically processes inputs:
- float[]: Converts to FloatTensor with added batch dimension
- int[]: Converts to IntTensor with added batch dimension
- Dimension Expansion: Adds batch dimension via expandDims(0)
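The batch-dimension handling described above can be illustrated without DJL. The following sketch is a plain-Java analog (class and method names are illustrative, not part of the engine's API) of what expandDims(0) does conceptually: a shape-[n] input becomes shape-[1, n] so the model always sees a leading batch axis.

```java
import java.util.Arrays;

// Plain-Java sketch (no DJL) of the translator's batch-dimension handling:
// expandDims(0) turns a shape-[n] array into shape-[1, n].
public class BatchDimSketch {

    // float[] input -> float[1][n], mirroring NDArray.expandDims(0)
    static float[][] expandDims0(float[] values) {
        return new float[][] { Arrays.copyOf(values, values.length) };
    }

    // int[] input -> int[1][n]
    static int[][] expandDims0(int[] values) {
        return new int[][] { Arrays.copyOf(values, values.length) };
    }

    public static void main(String[] args) {
        float[][] batched = expandDims0(new float[] {0.1f, 0.2f, 0.3f});
        // One batch row of three values
        System.out.println(batched.length + "x" + batched[0].length);
    }
}
```

In the real engine this expansion happens on DJL NDArrays inside the MapTranslator, not on raw Java arrays.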
Output Processing:
Outputs are automatically converted back to Java arrays:
- Tensor to Array: Converts DJL NDArrays to float arrays
- Named Outputs: Generates output names as "output_0", "output_1", etc.
- Type Preservation: Maintains original tensor data types
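The sequential naming convention above ("output_0", "output_1", ...) can be sketched as a small helper. This is an illustrative reconstruction of the documented behavior, not the engine's actual code; the class and method names are assumptions.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch of the documented output-naming convention: each output tensor,
// in order, is exposed under a sequential key "output_0", "output_1", ...
public class OutputNamingSketch {

    static Map<String, float[]> nameOutputs(List<float[]> outputTensors) {
        Map<String, float[]> named = new LinkedHashMap<>(); // preserves order
        for (int i = 0; i < outputTensors.size(); i++) {
            named.put("output_" + i, outputTensors.get(i));
        }
        return named;
    }

    public static void main(String[] args) {
        List<float[]> tensors = new ArrayList<>();
        tensors.add(new float[] {0.9f, 0.1f});
        tensors.add(new float[] {0.5f});
        System.out.println(nameOutputs(tensors).keySet());
    }
}
```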
Capabilities:
Feature            Supported  Notes
Batch Inference    Yes        Via batch dimension in tensors
Native Batching    Yes        Through DJL's batch processing
Max Batch Size     64         Configurable based on memory
GPU Support        Yes        Automatic CUDA detection via DJL
Dynamic Graphs     Yes        Supports TorchScript dynamic execution

Dependencies:
Requires DJL PyTorch engine:
- ai.djl:api (runtime)
- ai.djl.pytorch:pytorch-engine (runtime)
- ai.djl.pytorch:pytorch-native-auto (runtime)
Performance Features:
- Automatic GPU: DJL automatically uses GPU if CUDA is available
- Memory Management: NDManager for efficient tensor lifecycle
- Batch Optimization: Native batch processing through tensor operations
- Model Caching: DJL caches loaded models for repeated use
Thread Safety:
DJL Predictor instances are not thread-safe. For concurrent inference:
- Create separate engine instances per thread
- Use Predictor pooling for high-throughput scenarios
- Synchronize access to the infer(java.util.Map<java.lang.String, java.lang.Object>) method
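The "Predictor pooling" option above can be sketched with a small blocking pool. This is a hedged, generic illustration (the class name and structure are assumptions, not part of this library): each thread borrows exclusive use of one pooled object, which in practice would be a DJL Predictor created from the shared ZooModel.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.Supplier;

// Minimal blocking pool for objects that are not thread-safe (such as a
// DJL Predictor). acquire() hands a thread exclusive use of one instance;
// release() returns it for reuse by other threads.
public class ResourcePool<T> {
    private final BlockingQueue<T> pool;

    public ResourcePool(int size, Supplier<T> factory) {
        pool = new ArrayBlockingQueue<>(size);
        for (int i = 0; i < size; i++) {
            pool.add(factory.get());
        }
    }

    // Borrow a resource; blocks until one is free.
    public T acquire() {
        try {
            return pool.take();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for a pooled resource", e);
        }
    }

    // Return the resource so another thread can use it.
    public void release(T resource) {
        pool.offer(resource);
    }
}
```

A typical use would wrap every infer call in acquire/release, sizing the pool to the number of concurrent threads.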
Resource Management:
Always call close() to release native resources (GPU memory, file handles). The engine implements AutoCloseable for use with try-with-resources:

try (TorchScriptInferenceEngine engine = new TorchScriptInferenceEngine()) {
    engine.initialize(config);
    InferenceResult result = engine.infer(inputs);
}

- Since:
- 1.0.0
- Author:
- Nestor Martourez, Sr Software and Data Streaming Engineer @ CodedStreams
- See Also:
LocalInferenceEngine, Predictor, ZooModel, Deep Java Library Documentation
-
-
Nested Class Summary
-
Nested classes/interfaces inherited from interface com.codedstream.otterstream.inference.engine.InferenceEngine
InferenceEngine.EngineCapabilities
-
-
Field Summary
-
Fields inherited from class com.codedstream.otterstream.inference.engine.LocalInferenceEngine
initialized, loadedModel, modelConfig, modelLoader
-
-
Constructor Summary
Constructors

Constructor                      Description
TorchScriptInferenceEngine()
-
Method Summary
All Methods  Instance Methods  Concrete Methods

Modifier and Type                    Method                                         Description
void                                 close()                                        Closes the engine and releases all DJL and native resources.
InferenceEngine.EngineCapabilities   getCapabilities()                              Gets the engine's capabilities for PyTorch inference.
ModelMetadata                        getMetadata()                                  Gets metadata about the loaded PyTorch model.
InferenceResult                      infer(Map<String,Object> inputs)               Performs single inference on the provided inputs using the PyTorch model.
InferenceResult                      inferBatch(Map<String,Object>[] batchInputs)   Batch inference implementation.
void                                 initialize(ModelConfig config)                 Initializes the PyTorch inference engine by loading a TorchScript model.
Methods inherited from class com.codedstream.otterstream.inference.engine.LocalInferenceEngine
getModelConfig, isReady, loadModelDirectly
-
-
-
-
Method Detail
-
initialize
public void initialize(ModelConfig config) throws InferenceException
Initializes the PyTorch inference engine by loading a TorchScript model. The initialization process:
- Creates NDManager for tensor memory management
- Builds Criteria for model loading with Map-based I/O
- Configures TorchScriptInferenceEngine.MapTranslator for input/output processing
- Loads model using DJL's Criteria.loadModel()
- Creates Predictor for inference execution
DJL Automatic Features:
- Engine Selection: Automatically selects PyTorch engine
- GPU Detection: Uses CUDA if available, falls back to CPU
- Native Libraries: Loads PyTorch native libraries automatically
- Specified by:
  initialize in interface InferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Overrides:
  initialize in class LocalInferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Parameters:
  config - model configuration containing TorchScript model path
- Throws:
  InferenceException - if model loading fails or DJL is not properly configured
-
infer
public InferenceResult infer(Map<String,Object> inputs) throws InferenceException
Performs single inference on the provided inputs using the PyTorch model. The inference process:
- Inputs are processed by TorchScriptInferenceEngine.MapTranslator.processInput(ai.djl.translate.TranslatorContext, java.util.Map<java.lang.String, java.lang.Object>)
- Converted to DJL NDList with batch dimensions
- Executed through PyTorch engine (GPU if available)
- Outputs converted back to Map via TorchScriptInferenceEngine.MapTranslator.processOutput(ai.djl.translate.TranslatorContext, ai.djl.ndarray.NDList)
Input Requirements:
- float[] arrays: Converted to Float32 tensors
- int[] arrays: Converted to Int32 tensors
- Batch Dimension: Automatically added (expandDims(0))
Output Format:
Outputs are named sequentially as "output_0", "output_1", etc., containing float arrays extracted from output tensors.
- Specified by:
  infer in interface InferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Specified by:
  infer in class LocalInferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Parameters:
  inputs - map of input names to arrays (float[] or int[])
- Returns:
- inference result containing predictions and timing
- Throws:
InferenceException - if inference fails or inputs are invalid
-
inferBatch
public InferenceResult inferBatch(Map<String,Object>[] batchInputs) throws InferenceException
Batch inference implementation. TODO: Implement batch inference for PyTorch models. Potential approaches:
- Stack individual tensors into batch tensors
- Use DJL's batch predictor capabilities
- Implement custom batch translator
- Leverage PyTorch's native batch processing
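The first approach listed above (stacking individual tensors into batch tensors) can be sketched in plain Java. This is a hedged illustration of the idea, not an implementation of inferBatch; in the real engine the stacking would happen on DJL NDArrays for one input name across all requests.

```java
import java.util.Arrays;

// Sketch of the tensor-stacking approach for batch inference: per-request
// float[] inputs of shape [n] are stacked into a single [batch, n] array
// that a model could consume in one forward pass.
public class BatchStackSketch {

    static float[][] stack(float[][] rows) {
        int n = rows[0].length;
        float[][] batch = new float[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            if (rows[i].length != n) {
                // Stacking requires every request to share one input shape.
                throw new IllegalArgumentException("all rows must share one shape");
            }
            batch[i] = Arrays.copyOf(rows[i], n);
        }
        return batch;
    }

    public static void main(String[] args) {
        float[][] batch = stack(new float[][] {{1f, 2f}, {3f, 4f}});
        System.out.println(batch.length + "x" + batch[0].length);
    }
}
```

After the single batched forward pass, the results would be split back into one InferenceResult per request.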
- Specified by:
  inferBatch in interface InferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Specified by:
  inferBatch in class LocalInferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Parameters:
  batchInputs - array of input maps for batch processing
- Returns:
- batch inference results (currently returns null)
- Throws:
InferenceException - not currently implemented
-
close
public void close() throws InferenceException

Closes the engine and releases all DJL and native resources. Releases resources in reverse initialization order:
- Predictor: Stops inference execution threads
- ZooModel: Unloads PyTorch model from memory
- NDManager: Releases all tensor memory (GPU/CPU)
- Calls parent cleanup
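The reverse-order teardown above follows the usual rule for dependent native resources: release in the opposite order of creation. A minimal sketch (the class name and string stand-ins are illustrative, not the engine's code):

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of close() ordering: resources initialized as NDManager, ZooModel,
// Predictor are released as Predictor, ZooModel, NDManager, so nothing is
// torn down while a dependent object still needs it.
public class CloseOrderSketch {

    static List<String> closeInReverse(List<String> initOrder) {
        List<String> closed = new ArrayList<>();
        for (int i = initOrder.size() - 1; i >= 0; i--) {
            closed.add(initOrder.get(i)); // real code would call close() here
        }
        return closed;
    }

    public static void main(String[] args) {
        List<String> initOrder = List.of("NDManager", "ZooModel", "Predictor");
        System.out.println(closeInReverse(initOrder));
    }
}
```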
-
getMetadata
public ModelMetadata getMetadata()
Gets metadata about the loaded PyTorch model. TODO: Implement PyTorch metadata extraction via DJL. Potential metadata includes:
- Model architecture information
- Input/output tensor shapes and types
- GPU/CPU execution mode
- PyTorch version and model format
- Returns:
- model metadata (currently returns null, override for implementation)
-
getCapabilities
public InferenceEngine.EngineCapabilities getCapabilities()
Gets the engine's capabilities for PyTorch inference. PyTorch engine capabilities:
- Batch Inference: Supported through tensor batching
- Native Batching: Yes, via DJL's batch processing
- Max Batch Size: 64 (conservative default)
- GPU Support: Yes, automatic CUDA detection
- Specified by:
  getCapabilities in interface InferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Specified by:
  getCapabilities in class LocalInferenceEngine<ai.djl.repository.zoo.ZooModel<Map<String,Object>,Map<String,Object>>>
- Returns:
- engine capabilities indicating full PyTorch support
-
-