Class XGBoostInferenceEngine
- java.lang.Object
-
- com.codedstream.otterstream.inference.engine.LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
-
- com.codedstream.otterstream.xgboost.XGBoostInferenceEngine
-
- All Implemented Interfaces:
InferenceEngine<ml.dmlc.xgboost4j.java.Booster>
public class XGBoostInferenceEngine extends LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
XGBoost inference engine for gradient boosting tree models. This engine provides inference capabilities for XGBoost models using the XGBoost4J Java library. XGBoost is an optimized, distributed gradient boosting library designed for efficiency, flexibility, and portability, widely used for tabular and structured-data problems.
Supported XGBoost Features:
- Model Formats: XGBoost binary (.model), JSON (.json), UBJSON (.ubj)
- Task Types: Regression, binary classification, multi-class classification
- Batch Inference: Efficient batch prediction through matrix operations
- Missing Values: Native support for NaN as missing value indicator
- Thread Safety: Model predictions are thread-safe
Model Loading:
ModelConfig config = ModelConfig.builder()
    .modelPath("model.xgb")   // XGBoost model file
    .modelId("xgboost-model")
    .build();

XGBoostInferenceEngine engine = new XGBoostInferenceEngine();
engine.initialize(config);

Inference Example:
Map<String, Object> inputs = new HashMap<>();
inputs.put("age", 35.0f);
inputs.put("income", 75000.0f);
inputs.put("credit_score", 720.0f);
inputs.put("loan_amount", 25000.0f);

InferenceResult result = engine.infer(inputs);

// For regression/binary classification
float prediction = (float) result.getOutput("prediction");

// For multi-class classification
float[] probabilities = (float[]) result.getOutput("probabilities");
int predictedClass = (int) result.getOutput("prediction");

Feature Extraction:
The engine assumes input features are already in the correct order for the XGBoost model. Features are extracted in the order they appear in the input Map. For production use, implement feature ordering based on model metadata or configuration.
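The ordering caveat above can be made concrete. The sketch below is illustrative only (FeatureVectorizer and its configured order are hypothetical, not part of the engine's API): it extracts features in a fixed, model-defined order so that Map iteration order no longer matters, mapping absent features to Float.NaN, which XGBoost treats as missing.

```java
import java.util.List;
import java.util.Map;

// Hypothetical helper: extracts features in a fixed order (e.g. loaded from
// model metadata or configuration), independent of Map iteration order.
public class FeatureVectorizer {
    private final List<String> featureOrder;

    public FeatureVectorizer(List<String> featureOrder) {
        this.featureOrder = featureOrder;
    }

    public float[] toVector(Map<String, Object> inputs) {
        float[] vector = new float[featureOrder.size()];
        for (int i = 0; i < vector.length; i++) {
            Object value = inputs.get(featureOrder.get(i));
            // XGBoost treats NaN as "missing", so absent features map to NaN.
            vector[i] = (value instanceof Number)
                    ? ((Number) value).floatValue()
                    : Float.NaN;
        }
        return vector;
    }
}
```

A vectorizer like this would sit between the input Map and DMatrix construction, making feature order explicit instead of relying on Map ordering.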
Prediction Outputs:
Task Type | Output Format | Example
Regression | Single float value | {"prediction": 0.75}
Binary Classification | Single probability | {"prediction": 0.92}
Multi-class Classification | Probabilities array + class index | {"probabilities": [0.1, 0.8, 0.1], "prediction": 1}
Capabilities:
Feature | Supported | Notes
Batch Inference | Yes | Efficient matrix-based batch processing
Native Batching | No | Batch size limited by memory
Max Batch Size | 1000 | Conservative default for memory safety
GPU Support | Yes | When XGBoost built with GPU support
Missing Values | Yes | NaN represents missing values
Dependencies:
Requires the XGBoost4J Java library:
- ml.dmlc:xgboost4j (runtime)
- ml.dmlc:xgboost4j-linux-gpu (optional, for GPU support)
Performance Features:
- DMatrix Optimization: Efficient column-major data storage
- Thread Pool: XGBoost uses internal thread pool for prediction
- Memory Efficient: Automatic DMatrix disposal to prevent leaks
- Batch Processing: Significant speedup for batch predictions
Thread Safety:
Booster prediction methods are thread-safe according to the XGBoost documentation; multiple threads can call Booster.predict concurrently. However, DMatrix creation and disposal should be synchronized if matrices are shared between threads.
Resource Management:
XGBoost allocates native memory through DMatrix and Booster. Always call close() to release native resources:

try (XGBoostInferenceEngine engine = new XGBoostInferenceEngine()) {
    engine.initialize(config);
    InferenceResult result = engine.infer(inputs);
}

- Since:
- 1.0.0
- Author:
- Nestor Martourez, Sr Software and Data Streaming Engineer @ CodedStreams
- See Also:
LocalInferenceEngine, Booster, DMatrix, XGBoost Documentation
-
-
Nested Class Summary
-
Nested classes/interfaces inherited from interface com.codedstream.otterstream.inference.engine.InferenceEngine
InferenceEngine.EngineCapabilities
-
-
Field Summary
-
Fields inherited from class com.codedstream.otterstream.inference.engine.LocalInferenceEngine
initialized, loadedModel, modelConfig, modelLoader
-
-
Constructor Summary
Constructors
- XGBoostInferenceEngine()
-
Method Summary
Modifier and Type | Method | Description
void | close() | Closes the XGBoost engine and releases native resources.
InferenceEngine.EngineCapabilities | getCapabilities() | Gets the engine's capabilities for XGBoost inference.
ModelMetadata | getMetadata() | Gets metadata about the loaded XGBoost model.
InferenceResult | infer(Map<String,Object> inputs) | Performs single inference using the XGBoost model.
InferenceResult | inferBatch(Map<String,Object>[] batchInputs) | Performs batch inference using XGBoost's efficient matrix operations.
void | initialize(ModelConfig config) | Initializes the XGBoost inference engine by loading a model file.
-
Methods inherited from class com.codedstream.otterstream.inference.engine.LocalInferenceEngine
getModelConfig, isReady, loadModelDirectly
-
-
-
-
Method Detail
-
initialize
public void initialize(ModelConfig config) throws InferenceException
Initializes the XGBoost inference engine by loading a model file. Supports various XGBoost model formats:
- Binary: .model (default binary format)
- JSON: .json (human-readable, larger file size)
- UBJSON: .ubj (binary JSON, efficient storage)
GPU Support:
If XGBoost is compiled with GPU support and a GPU is available, predictions will automatically use GPU acceleration. Check XGBoost documentation for GPU compilation instructions.
- Specified by:
initialize in interface InferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Overrides:
initialize in class LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Parameters:
config - model configuration containing the XGBoost model file path
- Throws:
InferenceException - if model loading fails or the file is invalid
-
infer
public InferenceResult infer(Map<String,Object> inputs) throws InferenceException
Performs single inference using the XGBoost model. The inference process:
- Extracts features from the input Map to a float array
- Creates a DMatrix with shape [1, num_features]
- Calls Booster.predict for prediction
- Formats output based on the prediction task type
- Disposes the DMatrix to prevent memory leaks
Output Format Detection:
Automatically detects prediction task type based on output shape:
- Single value: Regression or binary classification
- Multiple values: Multi-class classification probabilities
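The shape-based detection rule above can be sketched in isolation (OutputFormatter and the output keys shown are illustrative, not the engine's actual internals): a single raw value is reported as the prediction directly, while a probability vector additionally yields the argmax class index.

```java
import java.util.HashMap;
import java.util.Map;

public class OutputFormatter {
    // Illustrative sketch of shape-based task detection: one value means
    // regression/binary classification, several mean multi-class probabilities.
    public static Map<String, Object> format(float[] rawOutput) {
        Map<String, Object> out = new HashMap<>();
        if (rawOutput.length == 1) {
            // Regression or binary classification: single value.
            out.put("prediction", rawOutput[0]);
        } else {
            // Multi-class: probability vector plus argmax class index.
            int best = 0;
            for (int i = 1; i < rawOutput.length; i++) {
                if (rawOutput[i] > rawOutput[best]) best = i;
            }
            out.put("probabilities", rawOutput);
            out.put("prediction", best);
        }
        return out;
    }
}
```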
Missing Values:
XGBoost handles missing values represented as Float.NaN. Features with NaN values are treated as missing during prediction.
- Specified by:
infer in interface InferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Specified by:
infer in class LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Parameters:
inputs - map of feature names to values (Number or float[])
- Returns:
- inference result with formatted predictions
- Throws:
InferenceException - if inference fails or feature extraction fails
IllegalArgumentException - for unsupported feature types
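The accepted input types (Number or float[]) and the IllegalArgumentException documented above suggest a conversion step roughly like the following; toFloats is a hypothetical helper written for illustration, not the engine's actual code.

```java
public class FeatureConversion {
    // Hypothetical conversion matching the documented contract:
    // Number -> single-element array, float[] passed through,
    // anything else rejected with IllegalArgumentException.
    public static float[] toFloats(String name, Object value) {
        if (value instanceof Number) {
            return new float[] { ((Number) value).floatValue() };
        }
        if (value instanceof float[]) {
            return (float[]) value;
        }
        throw new IllegalArgumentException(
                "Unsupported feature type for '" + name + "': "
                        + (value == null ? "null" : value.getClass().getName()));
    }
}
```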
-
inferBatch
public InferenceResult inferBatch(Map<String,Object>[] batchInputs) throws InferenceException
Performs batch inference using XGBoost's efficient matrix operations. Batch inference is significantly faster than sequential single predictions because XGBoost processes the entire batch matrix in native code. The method:
- Validates that all batch inputs have the same feature count
- Flattens batch features into a single array
- Creates a DMatrix with shape [batch_size, num_features]
- Performs batch prediction in a single native call
- Returns a 2D array of predictions
Memory Efficiency:
The batch features array uses contiguous memory for efficient data transfer to native XGBoost library. For very large batches, consider splitting into smaller batches to manage memory usage.
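Splitting oversized batches, as suggested above, can be done by the caller before invoking inferBatch. A minimal sketch (BatchSplitter is a hypothetical utility; 1000 mirrors the engine's documented max batch size):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class BatchSplitter {
    // Splits a large batch into chunks no larger than maxBatchSize, so each
    // chunk stays within the engine's documented limit (1000 by default).
    public static List<Map<String, Object>[]> chunk(
            Map<String, Object>[] batch, int maxBatchSize) {
        List<Map<String, Object>[]> chunks = new ArrayList<>();
        for (int start = 0; start < batch.length; start += maxBatchSize) {
            int end = Math.min(start + maxBatchSize, batch.length);
            chunks.add(Arrays.copyOfRange(batch, start, end));
        }
        return chunks;
    }
}
```

Each chunk can then be passed to inferBatch in turn, keeping peak native memory bounded by the chunk size rather than the full batch.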
Output Format:
{
  "batch_predictions": [
    [0.1, 0.9],  // Sample 1 predictions
    [0.7, 0.3],  // Sample 2 predictions
    ...
  ]
}
- Specified by:
inferBatch in interface InferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Specified by:
inferBatch in class LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Parameters:
batchInputs - array of input maps for batch processing
- Returns:
- inference result containing a 2D array of batch predictions
- Throws:
InferenceException - if batch inference fails or feature counts mismatch
-
getCapabilities
public InferenceEngine.EngineCapabilities getCapabilities()
Gets the engine's capabilities for XGBoost inference. XGBoost engine capabilities:
- Batch Inference: Yes, efficient matrix operations
- Native Batching: No, batch size limited by memory
- Max Batch Size: 1000 (conservative for memory safety)
- GPU Support: Yes, when XGBoost compiled with GPU support
- Specified by:
getCapabilities in interface InferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Specified by:
getCapabilities in class LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Returns:
- engine capabilities for XGBoost
-
close
public void close() throws InferenceException
Closes the XGBoost engine and releases native resources. Disposing the Booster releases:
- Tree structure memory
- Leaf values and split conditions
- Any GPU memory allocated
- Internal thread pool resources
Always call this method when finished to prevent native memory leaks.
- Specified by:
close in interface InferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Overrides:
close in class LocalInferenceEngine<ml.dmlc.xgboost4j.java.Booster>
- Throws:
InferenceException - if Booster disposal fails
-
getMetadata
public ModelMetadata getMetadata()
Gets metadata about the loaded XGBoost model. TODO: Implement XGBoost metadata extraction. XGBoost models contain metadata that can be extracted via:
- Booster.getModelDump(java.lang.String, boolean) - Tree structure dump
- Booster.getFeatureScore(java.lang.String[]) - Feature importance scores
- Booster.getAttrs() - Model attributes (objective, booster type)
- Number of features and trees
- Returns:
- model metadata (currently returns null, override for implementation)
-
-