otter-stream-pytorch

PyTorch TorchScript integration via Deep Java Library (DJL) with automatic GPU detection.

Module Overview

The PyTorch module enables TorchScript model inference using DJL's PyTorch engine with dynamic graph support. It provides flexible input/output translation and efficient memory management for PyTorch models in streaming applications.

🔷 PyTorch/DJL Engine (TorchScript)

TorchScript model inference using DJL's PyTorch engine with dynamic graph support.

  • TorchScript models (torch.jit.trace/script)
  • Automatic GPU detection via CUDA
  • DJL NDArray integration for tensor operations
  • Efficient memory management with NDManager
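The automatic GPU detection listed above follows DJL's standard device-resolution path. A minimal check, sketched with DJL's own Engine and Device APIs (class names from `ai.djl`, not from this module), might look like:

```java
import ai.djl.Device;
import ai.djl.engine.Engine;

public class DeviceCheck {
    public static void main(String[] args) {
        // Ask the PyTorch engine how many CUDA devices it can see.
        Engine engine = Engine.getEngine("PyTorch");
        int gpuCount = engine.getGpuCount();

        // Resolve to gpu(0) when CUDA is available, otherwise fall back
        // to the CPU -- no configuration needed on either path.
        Device device = gpuCount > 0 ? Device.gpu() : Device.cpu();
        System.out.println("GPUs: " + gpuCount + ", selected device: " + device);
    }
}
```

No code change is needed to move between CPU-only and CUDA hosts; DJL picks the device at runtime.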

🎨 Translator System (Data Processing)

Flexible input/output translation between Java objects and PyTorch tensors.

  • MapTranslator for Map-based I/O
  • Automatic batch dimension handling
  • Float and int array conversions
  • Extensible for custom data types
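For inputs that the Map-based translator does not cover, DJL lets you plug in a custom Translator. A sketch against DJL's real `ai.djl.translate.Translator` interface (the class name `FloatArrayTranslator` is hypothetical, not part of this module):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

// Hypothetical translator: float[] in, float[] out.
public class FloatArrayTranslator implements Translator<float[], float[]> {

    @Override
    public NDList processInput(TranslatorContext ctx, float[] input) {
        // The NDManager comes from the predictor's context; tensors it
        // creates are freed automatically when the context closes.
        NDArray array = ctx.getNDManager().create(input);
        return new NDList(array);
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        return list.singletonOrThrow().toFloatArray();
    }

    @Override
    public Batchifier getBatchifier() {
        // STACK adds the leading batch dimension for you.
        return Batchifier.STACK;
    }
}
```

Returning `Batchifier.STACK` is what makes the batch-dimension handling automatic: DJL stacks per-record tensors into one batch before invoking the model.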

Implementing PyTorch Inference

Guide to using PyTorch models exported as TorchScript.

  1. Add Maven Dependency
    <dependency>
        <groupId>com.codedstreams</groupId>
        <artifactId>otter-stream-pytorch</artifactId>
        <version>1.0.16</version>
    </dependency>
  2. Export PyTorch Model to TorchScript
    import torch
    
    # Method 1: Tracing (for static graphs)
    example_input = torch.randn(1, 3, 224, 224)
    traced_model = torch.jit.trace(model, example_input)
    traced_model.save("model.pt")
    
    # Method 2: Scripting (for dynamic graphs)
    scripted_model = torch.jit.script(model)
    scripted_model.save("model.pt")
  3. Configure and Use Engine
    ModelConfig config = ModelConfig.builder()
        .modelPath("model.pt")
        .modelId("pytorch-model")
        .format(ModelFormat.PYTORCH_TORCHSCRIPT)
        .build();
    
    TorchScriptInferenceEngine engine = new TorchScriptInferenceEngine();
    engine.initialize(config);
    
    Map<String, Object> inputs = Map.of(
        "input1", new float[]{0.1f, 0.2f, 0.3f},
        "input2", new int[]{1, 2, 3}
    );
    
    InferenceResult result = engine.infer(inputs);
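Under the hood, each map entry from step 3 is converted to a DJL NDArray before the model runs. A sketch of that conversion using DJL's real NDManager/NDArray API (plain DJL calls; the translator that performs this inside the engine belongs to this module):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

public class MapToTensors {
    public static void main(String[] args) {
        // NDManager owns native (off-heap) tensor memory; closing it frees
        // every NDArray it created, so scoping replaces manual cleanup.
        try (NDManager manager = NDManager.newBaseManager()) {
            // Same inputs as step 3, converted to tensors:
            NDArray input1 = manager.create(new float[] {0.1f, 0.2f, 0.3f});
            NDArray input2 = manager.create(new int[] {1, 2, 3});

            // expandDims(0) adds the leading batch dimension: (3) -> (1, 3)
            NDList batch = new NDList(input1.expandDims(0), input2.expandDims(0));
            System.out.println(batch.get(0).getShape()); // batched shape, e.g. (1, 3)
        } // manager close releases all native tensor memory here
    }
}
```

This is also the memory-management pattern behind the NDManager bullet above: tensors are tied to a manager scope rather than freed individually.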
