Overview
Otter Streams integrates with Flink's AsyncDataStream API. You wrap a
configured InferenceEngine inside AsyncModelInferenceFunction and
hand it to Flink, which handles the thread pool, back-pressure, timeout, and ordered /
unordered result delivery. Your code supplies only the engine factory, the feature
extraction logic, and a result transformer.
DataStream<T> input
  │
  ├─ feature extraction (you supply a MapFunction or inline lambda)
  │
  └─▶ AsyncDataStream.unorderedWait(
          inputStream,
          AsyncModelInferenceFunction,   ← Otter Streams
          timeoutMs,
          TimeUnit.MILLISECONDS,
          maxConcurrentRequests
      )
        │
        └─▶ DataStream<InferenceResult>
              │
              └─▶ your downstream operators (filter, map, sink, …)
InferenceEngine instances loaded
via the DataStream path are stored in the same ModelCache singleton as those
loaded by the SQL connector. If both paths are active in the same JVM process they share
the cache — models are loaded only once.
AsyncModelInferenceFunction
AsyncModelInferenceFunction<IN, OUT> extends Flink's
RichAsyncFunction<IN, OUT>. It manages the engine lifecycle, calling the
engine's initialize() from the operator's open() and the engine's
close() from the operator's close(), so you never handle either manually.
Constructor signature
// Generic constructor — supply the engine factory lambda
AsyncModelInferenceFunction<IN, OUT> fn = new AsyncModelInferenceFunction<>(
    InferenceConfig config,                                      // model + runtime settings
    Function<InferenceConfig, InferenceEngine<?>> engineFactory, // engine factory
    Function<IN, Map<String, Object>> featureExtractor,          // input → feature map
    BiFunction<IN, InferenceResult, OUT> resultTransformer       // (input, result) → output type
);
The result transformer receives the original input element alongside the
InferenceResult, so the output record can carry identifiers and fields from the
input (every example below relies on this).
Applying to a stream
DataStream<MyOutput> predictions = AsyncDataStream.unorderedWait(
inputStream,
inferenceFunction,
5_000L, // per-record timeout in milliseconds
TimeUnit.MILLISECONDS,
100 // max records in flight concurrently
);
Use orderedWait instead of unorderedWait if downstream operators
require records in input order. Ordered mode has slightly lower throughput because
completed futures are buffered until all earlier in-flight records have also completed.
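orderedWait takes the same arguments; a sketch mirroring the call above:
// Identical parameters; results are emitted in the original input order.
DataStream<MyOutput> orderedPredictions = AsyncDataStream.orderedWait(
    inputStream,
    inferenceFunction,
    5_000L,                // per-record timeout in milliseconds
    TimeUnit.MILLISECONDS,
    100                    // max records in flight concurrently
);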
ONNX Engine — OnnxInferenceEngine
The ONNX engine wraps Microsoft ONNX Runtime 1.23.2. It supports CPU, CUDA, and TensorRT
execution providers and handles float32 / int64 tensor construction automatically from a
Map<String, Object>.
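As an illustration of that automatic construction, a sketch of a mixed-dtype feature map; the exact mapping rules shown in the comments (boxed Float → float32, long[] → int64) are an assumption, not a documented contract:
Map<String, Object> features = Map.of(
    "amount", 42.5f,                          // Float value → float32 scalar tensor (assumed)
    "token_ids", new long[] {101, 2023, 102}  // long[] → int64 vector tensor (assumed)
);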
Maven dependency
<dependency>
    <groupId>com.codedstreams</groupId>
    <artifactId>otter-stream-onnx</artifactId>
    <version>1.0.17</version>
</dependency>
Complete DataStream example
import com.codedstreams.otterstream.inference.config.InferenceConfig;
import com.codedstreams.otterstream.inference.config.ModelConfig;
import com.codedstreams.otterstream.inference.config.ModelFormat;
import com.codedstreams.otterstream.inference.function.AsyncModelInferenceFunction;
import com.codedstreams.otterstream.inference.model.InferenceResult;
import com.codedstreams.otterstream.onnx.OnnxInferenceEngine;
import org.apache.flink.streaming.api.datastream.AsyncDataStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.TimeUnit;
public class OnnxFraudDetectionJob {
public static void main(String[] args) throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// ── 1. Model configuration ───────────────────────────────────────────
ModelConfig modelConfig = ModelConfig.builder()
.modelId("fraud-detector")
.modelPath("s3a://ml-models/fraud-detector/v1/fraud.onnx")
// Execution providers tried in order; falls back to CPU automatically
.modelOption("providers", "CUDAExecutionProvider,CPUExecutionProvider")
.modelOption("intra_op_num_threads", "4")
.format(ModelFormat.ONNX)
.modelVersion("v1")
.build();
// ── 2. Runtime / inference configuration ─────────────────────────────
InferenceConfig inferenceConfig = InferenceConfig.builder()
.modelConfig(modelConfig)
.batchSize(32)
.timeout(Duration.ofSeconds(5))
.maxRetries(3)
.enableCaching(true)
.cacheSize(50_000)
.cacheTtl(Duration.ofMinutes(60))
.enableMetrics(true)
.metricsPrefix("fraud.onnx")
.build();
// ── 3. Build the async inference function ─────────────────────────────
AsyncModelInferenceFunction<Transaction, ScoredTransaction> fn =
new AsyncModelInferenceFunction<>(
inferenceConfig,
// Engine factory — called once per task slot in open()
cfg -> new OnnxInferenceEngine(),
// Feature extractor — Transaction → Map<String, Object>
tx -> Map.of(
"amount", (float) tx.getAmount(),
"hour_of_day", (float) tx.getHourOfDay(),
"day_of_week", (float) tx.getDayOfWeek(),
"is_international", tx.isInternational() ? 1.0f : 0.0f,
"card_type_enc", (float) tx.getCardTypeEncoded()
),
// Result transformer — (Transaction, InferenceResult) → ScoredTransaction
(tx, result) -> {
float score = ((float[]) result.getOutputs().get("output"))[0];
return new ScoredTransaction(
tx.getTransactionId(),
score,
score >= 0.85 ? "CRITICAL"
: score >= 0.65 ? "HIGH"
: score >= 0.40 ? "MEDIUM" : "LOW"
);
}
);
// ── 4. Source → async inference → sink ───────────────────────────────
DataStream<Transaction> transactions = env
.addSource(new KafkaTransactionSource())
.name("kafka-transactions");
DataStream<ScoredTransaction> scored = AsyncDataStream.unorderedWait(
transactions,
fn,
5_000L,
TimeUnit.MILLISECONDS,
100 // max concurrent async inference calls
).name("onnx-inference");
scored
.filter(s -> s.getFraudScore() >= 0.40f)
.addSink(new KafkaFraudAlertSink())
.name("fraud-alerts-sink");
env.execute("ONNX Fraud Detection Pipeline");
}
}
TensorFlow Engine — TensorFlowInferenceEngine
Loads a TensorFlow 2.x SavedModel directory. Signature names and tensor
names are discovered automatically from saved_model.pb at initialisation time.
GPU acceleration is available when the GPU-flavour TensorFlow native library is present on
the classpath.
Maven dependency
<dependency>
    <groupId>com.codedstreams</groupId>
    <artifactId>otter-stream-tensorflow</artifactId>
    <version>1.0.17</version>
</dependency>
DataStream example
import com.codedstreams.otterstream.tensorflow.TensorFlowInferenceEngine;
ModelConfig tfConfig = ModelConfig.builder()
.modelId("image-classifier")
// Path to the SavedModel DIRECTORY (not the .pb file directly)
.modelPath("/mnt/models/resnet50/saved_model/")
.format(ModelFormat.TENSORFLOW_SAVEDMODEL)
// Optional: specify the serving signature name (default: "serving_default")
.modelOption("signature_name", "serving_default")
.modelVersion("v2")
.build();
InferenceConfig tfInferenceConfig = InferenceConfig.builder()
.modelConfig(tfConfig)
.batchSize(16)
.timeout(Duration.ofSeconds(10))
.enableMetrics(true)
.metricsPrefix("image.tf")
.build();
AsyncModelInferenceFunction<ImageEvent, ClassificationResult> tfFn =
new AsyncModelInferenceFunction<>(
tfInferenceConfig,
// Engine factory
cfg -> {
TensorFlowInferenceEngine engine = new TensorFlowInferenceEngine();
// Optionally read cached tensor names after init
// List<String> inputs = engine.getCachedInputNames();
// List<String> outputs = engine.getCachedOutputNames();
return engine;
},
// Feature extractor: flatten image pixels to float[]
img -> Map.of(
"input_1", img.getPixelsAsFloat32() // float[224*224*3]
),
// Result transformer
(img, result) -> {
float[] logits = (float[]) result.getOutputs().get("predictions");
int classIndex = argmax(logits);
return new ClassificationResult(img.getId(), classIndex, logits[classIndex]);
}
);
DataStream<ClassificationResult> classified = AsyncDataStream.unorderedWait(
imageStream, tfFn, 10_000L, TimeUnit.MILLISECONDS, 50
).name("tf-image-classification");
// Downstream: write top-k classifications to Elasticsearch
classified
.filter(r -> r.getConfidence() >= 0.80f)
.addSink(new ElasticsearchClassificationSink());
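The argmax(...) call in the result transformer is a plain helper, not an Otter Streams API; a minimal version:
// Index of the largest logit, i.e. the predicted class.
static int argmax(float[] values) {
    int best = 0;
    for (int i = 1; i < values.length; i++) {
        if (values[i] > values[best]) {
            best = i;
        }
    }
    return best;
}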
Add the matching TensorFlow native artifact to the classpath:
tensorflow-core-platform-linux-x86_64 for CPU-only execution, or
tensorflow-core-platform-linux-gpu-x86_64 to enable GPU acceleration.
PyTorch Engine — TorchScriptInferenceEngine
Loads TorchScript models (torch.jit.trace / torch.jit.script)
via Deep Java Library 0.25.0. GPUs are detected automatically through CUDA. Tensor memory
is managed through DJL NDManager scopes that are closed after each inference call.
Maven dependency
<dependency>
    <groupId>com.codedstreams</groupId>
    <artifactId>otter-stream-pytorch</artifactId>
    <version>1.0.17</version>
</dependency>
Export model in Python first
import torch
model.eval()
# Tracing — use for models with static control flow
example = torch.randn(1, 128) # batch=1, features=128
traced = torch.jit.trace(model, example)
traced.save("sentiment.pt")
# Scripting — use for models with dynamic control flow (if/for)
scripted = torch.jit.script(model)
scripted.save("sentiment_scripted.pt")
DataStream example
import com.codedstreams.otterstream.pytorch.TorchScriptInferenceEngine;
ModelConfig ptConfig = ModelConfig.builder()
.modelId("sentiment-classifier")
.modelPath("s3a://ml-models/nlp/sentiment.pt")
.format(ModelFormat.PYTORCH_TORCHSCRIPT)
// DJL engine auto-detects CUDA; override with "cpu" if needed
.modelOption("device", "auto")
.build();
InferenceConfig ptInferenceConfig = InferenceConfig.builder()
.modelConfig(ptConfig)
.batchSize(64)
.timeout(Duration.ofSeconds(8))
.enableMetrics(true)
.metricsPrefix("sentiment.pytorch")
.build();
AsyncModelInferenceFunction<ReviewEvent, SentimentResult> ptFn =
new AsyncModelInferenceFunction<>(
ptInferenceConfig,
// Engine factory — TorchScript engine via DJL
cfg -> new TorchScriptInferenceEngine(),
// Feature extractor — tokenised review → float[] embedding
review -> Map.of(
"input_ids", review.getTokenIds(), // int[]
"attention_mask", review.getAttentionMask() // int[]
),
// Result transformer — logits[2]: [neg, pos]
(review, result) -> {
float[] logits = (float[]) result.getOutputs().get("logits");
float posScore = softmax(logits)[1];
return new SentimentResult(
review.getReviewId(),
posScore >= 0.5f ? "POSITIVE" : "NEGATIVE",
posScore
);
}
);
DataStream<SentimentResult> sentiments = AsyncDataStream.unorderedWait(
reviewStream, ptFn, 8_000L, TimeUnit.MILLISECONDS, 80
).name("pytorch-sentiment");
sentiments.addSink(new BigQuerySentimentSink());
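As with argmax earlier, softmax(...) is an ordinary helper rather than part of the library; a minimal numerically stable version:
// Numerically stable softmax: subtract the max logit before exponentiating.
static float[] softmax(float[] logits) {
    float max = Float.NEGATIVE_INFINITY;
    for (float v : logits) max = Math.max(max, v);
    double sum = 0.0;
    double[] exp = new double[logits.length];
    for (int i = 0; i < logits.length; i++) {
        exp[i] = Math.exp(logits[i] - max);
        sum += exp[i];
    }
    float[] probs = new float[logits.length];
    for (int i = 0; i < logits.length; i++) {
        probs[i] = (float) (exp[i] / sum);
    }
    return probs;
}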
XGBoost Engine — XGBoostInferenceEngine
Wraps XGBoost4J 3.1.1. Converts the feature map to a DMatrix per call;
thread-safe because Booster.predict() is re-entrant in XGBoost4J.
Supports binary classification, multi-class, and regression objectives.
Maven dependency
<dependency>
    <groupId>com.codedstreams</groupId>
    <artifactId>otter-stream-xgboost</artifactId>
    <version>1.0.17</version>
</dependency>
DataStream example
import com.codedstreams.otterstream.xgboost.XGBoostInferenceEngine;
// ── Model supports three file formats ───────────────────────────────────────
// ModelFormat.XGBOOST_BINARY → .bin / .model (fastest loading)
// ModelFormat.XGBOOST_JSON → .json (portable, readable)
// ModelFormat.XGBOOST_UBJSON → .ubj (compact binary JSON)
ModelConfig xgbConfig = ModelConfig.builder()
.modelId("credit-scorer")
.modelPath("s3a://ml-models/credit/credit_v3.json")
.format(ModelFormat.XGBOOST_JSON)
.modelVersion("v3")
.build();
InferenceConfig xgbInferenceConfig = InferenceConfig.builder()
.modelConfig(xgbConfig)
.batchSize(128) // XGBoost handles large batches efficiently
.timeout(Duration.ofSeconds(3))
.enableCaching(true)
.cacheSize(100_000)
.cacheTtl(Duration.ofMinutes(30))
.enableMetrics(true)
.metricsPrefix("credit.xgboost")
.build();
AsyncModelInferenceFunction<CreditApplication, CreditDecision> xgbFn =
new AsyncModelInferenceFunction<>(
xgbInferenceConfig,
// Engine factory
cfg -> new XGBoostInferenceEngine(),
// Feature extractor — tabular features as float values
// XGBoost expects all features in a flat float array;
// the engine maps feature names to column positions using the model's
// feature_names metadata. Use Float.NaN for missing values.
app -> Map.of(
"age", (float) app.getAge(),
"annual_income", (float) app.getAnnualIncome(),
"credit_score", (float) app.getCreditScore(),
"debt_to_income", (float) app.getDebtToIncomeRatio(),
"employment_years", (float) app.getEmploymentYears(),
"num_open_accounts",(float) app.getNumOpenAccounts(),
// Missing value → Float.NaN; XGBoost handles natively
"previous_defaults", app.hasPreviousDefault() ? 1.0f : 0.0f
),
// Result transformer — binary:logistic output is a single float [0, 1]
(app, result) -> {
float probability = ((float[]) result.getOutputs().get("prediction"))[0];
String decision = probability >= 0.70f ? "DECLINE"
: probability >= 0.40f ? "REVIEW"
: "APPROVE";
return new CreditDecision(
app.getApplicationId(), decision, probability
);
}
);
DataStream<CreditDecision> decisions = AsyncDataStream.unorderedWait(
applicationStream, xgbFn, 3_000L, TimeUnit.MILLISECONDS, 200
).name("xgboost-credit-scoring");
decisions
.keyBy(CreditDecision::getDecision)
.addSink(new DatabaseDecisionSink());
PMML Engine — PmmlInferenceEngine
Evaluates PMML 4.x documents via JPMML-Evaluator 1.5.16. All built-in PMML pre-processing
transforms (normalization, discretization, outlier treatment) are applied automatically.
Thread-safe via immutable Evaluator instances.
Maven dependency
<dependency>
    <groupId>com.codedstreams</groupId>
    <artifactId>otter-stream-pmml</artifactId>
    <version>1.0.17</version>
</dependency>
DataStream example
import com.codedstreams.otterstream.pmml.PmmlInferenceEngine;
ModelConfig pmmlConfig = ModelConfig.builder()
.modelId("churn-predictor")
.modelPath("/mnt/models/churn/churn_logistic.pmml")
.format(ModelFormat.PMML)
.modelVersion("v1")
.build();
InferenceConfig pmmlInferenceConfig = InferenceConfig.builder()
.modelConfig(pmmlConfig)
.batchSize(1) // PMML evaluator processes one record at a time internally
.timeout(Duration.ofSeconds(5))
.enableMetrics(true)
.metricsPrefix("churn.pmml")
.build();
AsyncModelInferenceFunction<CustomerEvent, ChurnPrediction> pmmlFn =
new AsyncModelInferenceFunction<>(
pmmlInferenceConfig,
// Engine factory
cfg -> new PmmlInferenceEngine(),
// Feature extractor — PMML uses string-keyed field names matching the
// DataDictionary in the .pmml file exactly (case-sensitive).
// Values can be String, Integer, Double — PMML handles coercion.
customer -> Map.of(
"tenure_months", customer.getTenureMonths(),
"monthly_charges", customer.getMonthlyCharges(),
"total_charges", customer.getTotalCharges(),
"contract_type", customer.getContractType(), // "Month-to-month"
"payment_method", customer.getPaymentMethod(), // "Electronic check"
"internet_service", customer.getInternetService(), // "Fiber optic"
"num_support_tickets", customer.getSupportTickets()
),
// Result transformer — PMML outputs a Map of output field names
(customer, result) -> {
// PMML typically outputs both the predicted class and probabilities
String churnClass = (String) result.getOutputs().get("predictedChurn");
Double churnProb = (Double) result.getOutputs().get("probability_Yes");
return new ChurnPrediction(
customer.getCustomerId(),
"Yes".equals(churnClass),
churnProb != null ? churnProb.floatValue() : 0.0f
);
}
);
DataStream<ChurnPrediction> predictions = AsyncDataStream.unorderedWait(
customerStream, pmmlFn, 5_000L, TimeUnit.MILLISECONDS, 50
).name("pmml-churn-prediction");
predictions
.filter(ChurnPrediction::isChurning)
.addSink(new CRMRetentionSink());
Remote Engine — HttpInferenceEngine / SageMakerInferenceClient
Routes inference calls to external endpoints — REST APIs, AWS SageMaker, Google Vertex AI, Azure ML, or any gRPC model server. Uses OkHttp for HTTP and Netty for gRPC. Configurable retry, timeout, and circuit-breaker policies. Ideal when the model is too large to embed in the Flink JAR, or when a dedicated model server is required.
Maven dependency
<dependency>
    <groupId>com.codedstreams</groupId>
    <artifactId>otter-stream-remote</artifactId>
    <version>1.0.17</version>
</dependency>
HTTP REST endpoint
import com.codedstreams.otterstream.remote.HttpInferenceEngine;
import com.codedstreams.otterstream.inference.config.AuthConfig;
ModelConfig httpConfig = ModelConfig.builder()
.modelId("external-model")
.endpointUrl("https://api.mymodelservice.com/v1/predict")
.format(ModelFormat.REMOTE_HTTP)
.authConfig(AuthConfig.builder()
.addHeader("Authorization", "Bearer " + System.getenv("MODEL_API_TOKEN"))
.addHeader("X-Api-Version", "2024-01")
.build())
// Request/response content type
.modelOption("content_type", "application/json")
// Connection pool settings
.modelOption("max_connections", "50")
.modelOption("connect_timeout_ms", "2000")
.build();
InferenceConfig httpInferenceConfig = InferenceConfig.builder()
.modelConfig(httpConfig)
.timeout(Duration.ofSeconds(8))
.maxRetries(3)
.retryDelay(Duration.ofMillis(200))
.enableMetrics(true)
.metricsPrefix("external.http")
.build();
AsyncModelInferenceFunction<InputEvent, ScoredEvent> httpFn =
new AsyncModelInferenceFunction<>(
httpInferenceConfig,
cfg -> new HttpInferenceEngine(),
event -> Map.of(
"feature_a", event.getFeatureA(),
"feature_b", event.getFeatureB()
),
(event, result) -> new ScoredEvent(
event.getId(),
(Float) result.getOutputs().get("score")
)
);
DataStream<ScoredEvent> scored = AsyncDataStream.unorderedWait(
inputStream, httpFn, 8_000L, TimeUnit.MILLISECONDS, 30
).name("http-remote-inference");
AWS SageMaker endpoint
import com.codedstreams.otterstream.remote.SageMakerInferenceClient;
ModelConfig sagemakerConfig = ModelConfig.builder()
.modelId("sagemaker-fraud")
// SageMaker endpoint name (not the full ARN)
.endpointUrl("fraud-detection-prod-endpoint")
.format(ModelFormat.SAGEMAKER)
.authConfig(AuthConfig.builder()
// "accessKey:secretKey" pair; omit authConfig to use the instance IAM role instead
.apiKey(System.getenv("AWS_ACCESS_KEY_ID")
+ ":" + System.getenv("AWS_SECRET_ACCESS_KEY"))
.build())
.modelOption("region", "us-east-1")
.modelOption("content_type", "application/json")
.build();
InferenceConfig smInferenceConfig = InferenceConfig.builder()
.modelConfig(sagemakerConfig)
.timeout(Duration.ofSeconds(10))
.maxRetries(3)
.enableMetrics(true)
.metricsPrefix("fraud.sagemaker")
.build();
AsyncModelInferenceFunction<Transaction, ScoredTransaction> smFn =
new AsyncModelInferenceFunction<>(
smInferenceConfig,
cfg -> new SageMakerInferenceClient(),
tx -> Map.of("amount", tx.getAmount(), "merchant", tx.getMerchantId()),
(tx, result) -> new ScoredTransaction(tx, (Float) result.getOutputs().get("score"))
);
DataStream<ScoredTransaction> smScored = AsyncDataStream.unorderedWait(
transactionStream, smFn, 10_000L, TimeUnit.MILLISECONDS, 50
).name("sagemaker-inference");
ModelCache Integration in DataStream Jobs
For jobs that need to hot-swap models or pre-warm engines before the streaming job starts,
you can interact with ModelCache directly. This is optional — the
AsyncModelInferenceFunction manages the cache automatically, but direct access
is useful for custom model management logic.
import com.codedstreams.otterstreams.sql.loader.ModelCache;
import com.codedstreams.otterstream.inference.engine.InferenceEngine;
// ── Pre-warm the cache before the streaming job starts ──────────────────────
ModelCache cache = ModelCache.getInstance();
// Build and initialise an engine, then register it by name
OnnxInferenceEngine engine = new OnnxInferenceEngine();
engine.initialize(modelConfig);
cache.putEngine("fraud-detector", engine);
// ── Check if a model is already loaded ─────────────────────────────────────
InferenceEngine<?> existing = cache.getEngine("fraud-detector");
if (existing == null) {
// Model not loaded — trigger load
}
// ── Hot-swap a model version at runtime ────────────────────────────────────
// 1. Initialise new version engine
OnnxInferenceEngine v2Engine = new OnnxInferenceEngine();
v2Engine.initialize(v2Config);
// 2. Atomically replace (ModelCache.putEngine is thread-safe)
cache.putEngine("fraud-detector", v2Engine);
// Old engine will be garbage-collected; its close() is called on eviction.
// ── Invalidate a model (force reload on next inference) ────────────────────
cache.invalidate("fraud-detector");
// ── Clear all cached models (e.g., before shutdown) ────────────────────────
cache.invalidateAll();
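Combining these calls, a sketch of scheduled hot-swapping; resolveLatestConfig is a hypothetical helper and the ten-minute interval is illustrative:
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
scheduler.scheduleAtFixedRate(() -> {
    try {
        // resolveLatestConfig: hypothetical helper that builds a ModelConfig
        // pointing at the newest model version under the given prefix
        ModelConfig latest = resolveLatestConfig("s3a://ml-models/fraud-detector/");
        OnnxInferenceEngine fresh = new OnnxInferenceEngine();
        fresh.initialize(latest);
        // Thread-safe swap; the evicted engine's close() runs on eviction
        ModelCache.getInstance().putEngine("fraud-detector", fresh);
    } catch (Exception e) {
        // Keep serving the current engine if the refresh fails
    }
}, 10, 10, TimeUnit.MINUTES);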
MinIO Model Loading in DataStream Jobs
In the DataStream path, model loading from MinIO is controlled through the
ModelConfig.modelPath URI. When the AsyncModelInferenceFunction
calls engine.initialize(config) in open(), the engine
calls the MinioModelLoader internally to download the model file into a
local temp directory before loading it into memory. Set the MinIO credentials as Flink
configuration properties or as environment variables.
import org.apache.flink.configuration.Configuration;

// ── Pass MinIO credentials via Flink configuration ─────────────────────────
Configuration flinkConf = new Configuration();
flinkConf.setString("otter.minio.endpoint", "http://minio:9000");
flinkConf.setString("otter.minio.access-key", System.getenv("MINIO_ACCESS_KEY"));
flinkConf.setString("otter.minio.secret-key", System.getenv("MINIO_SECRET_KEY"));
// MinIO generally requires path-style access (virtual-hosted style needs wildcard DNS)
flinkConf.setString("otter.minio.path-style", "true");
StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment(flinkConf);
// ── ModelConfig references the s3a:// URI directly ─────────────────────────
ModelConfig minioConfig = ModelConfig.builder()
.modelId("iot-anomaly-detector")
// The loader strips the s3a:// scheme and uses the MinIO endpoint from the Flink config
.modelPath("s3a://ml-models/anomaly-detector/v2/model.json")
.format(ModelFormat.XGBOOST_JSON)
.modelVersion("v2")
.build();
InferenceConfig minioInferenceConfig = InferenceConfig.builder()
.modelConfig(minioConfig)
.batchSize(64)
.timeout(Duration.ofSeconds(5))
.enableCaching(true)
.cacheTtl(Duration.ofMinutes(60))
.enableMetrics(true)
.metricsPrefix("iot.xgboost.minio")
.build();
AsyncModelInferenceFunction<SensorReading, AnomalyScore> fn =
new AsyncModelInferenceFunction<>(
minioInferenceConfig,
cfg -> new XGBoostInferenceEngine(), // MinIO download happens inside initialize()
reading -> Map.of(
"temperature", (float) reading.getTemperature(),
"pressure", (float) reading.getPressure(),
"vibration", (float) reading.getVibration(),
"rpm", (float) reading.getRpm()
),
(reading, result) -> new AnomalyScore(
reading.getDeviceId(),
((float[]) result.getOutputs().get("prediction"))[0]
)
);
DataStream<AnomalyScore> anomalies = AsyncDataStream.unorderedWait(
sensorStream, fn, 5_000L, TimeUnit.MILLISECONDS, 100
).name("xgboost-minio-anomaly");
Full DataStream Pipeline Example
A complete, self-contained Flink job showing real-world patterns: multi-source Kafka ingestion, feature engineering window, ONNX inference, risk tiering, dual sinks, and dead-letter handling.
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.streaming.api.datastream.*;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.TimeUnit;
public class FullFraudPipelineJob {
public static void main(String[] args) throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(8);
env.enableCheckpointing(30_000L); // checkpoint every 30 s
// ── 1. Source: raw transactions from Kafka ───────────────────────────
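// kafkaSource: a KafkaSource<Transaction> assumed to be built earlier in the job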
DataStream<Transaction> rawTxns = env
.fromSource(kafkaSource, WatermarkStrategy
.<Transaction>forBoundedOutOfOrderness(Duration.ofSeconds(5))
.withTimestampAssigner((tx, ts) -> tx.getEventTimeMs()),
"kafka-transactions")
.name("kafka-source");
// ── 2. Feature engineering: compute 5-min velocity per user ──────────
DataStream<EnrichedTransaction> enriched = rawTxns
.keyBy(Transaction::getUserId)
.window(TumblingEventTimeWindows.of(Time.minutes(5)))
.process(new VelocityFeatureProcess()) // adds txn_count, total_amount_5m
.name("velocity-features");
// ── 3. ONNX async inference ──────────────────────────────────────────
ModelConfig modelConfig = ModelConfig.builder()
.modelId("fraud-v3")
.modelPath("s3a://ml-models/fraud/v3/fraud.onnx")
.format(ModelFormat.ONNX)
.modelOption("providers", "CUDAExecutionProvider,CPUExecutionProvider")
.build();
InferenceConfig inferenceConfig = InferenceConfig.builder()
.modelConfig(modelConfig)
.batchSize(32)
.timeout(Duration.ofSeconds(5))
.maxRetries(2)
.enableCaching(true)
.cacheSize(50_000)
.cacheTtl(Duration.ofMinutes(60))
.enableMetrics(true)
.metricsPrefix("fraud.v3.onnx")
.build();
AsyncModelInferenceFunction<EnrichedTransaction, ScoredTransaction> fn =
new AsyncModelInferenceFunction<>(
inferenceConfig,
cfg -> new OnnxInferenceEngine(),
tx -> Map.of(
"amount", (float) tx.getAmount(),
"hour_of_day", (float) tx.getHourOfDay(),
"txn_count_5m", (float) tx.getTxnCount5m(),
"total_amount_5m", (float) tx.getTotalAmount5m(),
"is_international", tx.isInternational() ? 1.0f : 0.0f
),
(tx, result) -> {
if (!result.isSuccess()) {
// Propagate failed inference as a sentinel score of -1
return ScoredTransaction.failed(tx);
}
float score = ((float[]) result.getOutputs().get("output"))[0];
return ScoredTransaction.of(tx, score);
}
);
DataStream<ScoredTransaction> scored = AsyncDataStream.unorderedWait(
enriched, fn, 5_000L, TimeUnit.MILLISECONDS, 100
).name("onnx-inference");
// ── 4. Split: alerts vs dead-letters (failed inference) ──────────────
OutputTag<ScoredTransaction> deadLetterTag =
new OutputTag<>("dead-letter"){};
SingleOutputStreamOperator<ScoredTransaction> alerts = scored
.process(new ProcessFunction<ScoredTransaction, ScoredTransaction>() {
@Override
public void processElement(ScoredTransaction s,
Context ctx,
Collector<ScoredTransaction> out) {
if (s.isFailed()) {
ctx.output(deadLetterTag, s);
} else if (s.getFraudScore() >= 0.40f) {
out.collect(s);
}
}
}).name("fraud-router");
DataStream<ScoredTransaction> deadLetters = alerts.getSideOutput(deadLetterTag);
// ── 5. Sinks ─────────────────────────────────────────────────────────
alerts
.map(s -> FraudAlert.fromScored(s))
.addSink(new KafkaFraudAlertSink())
.name("fraud-alert-kafka-sink");
deadLetters
.addSink(new S3DeadLetterSink("s3://fraud-dead-letters/"))
.name("dead-letter-s3-sink");
env.execute("Otter Streams — Full Fraud Detection DataStream Pipeline");
}
}