Class SageMakerInferenceClient

  • All Implemented Interfaces:
    InferenceEngine<Void>

    public class SageMakerInferenceClient
    extends RemoteInferenceEngine
    AWS SageMaker remote inference client for hosted ML models.

    This engine provides integration with AWS SageMaker endpoints for inference on models hosted in the AWS cloud. It uses the AWS SDK for Java v2 to communicate with the SageMaker Runtime API and supports both static credentials and AWS IAM role-based authentication.

    Supported Features:

    • SageMaker Endpoints: Integration with deployed SageMaker model endpoints
    • AWS Authentication: Static credentials or IAM role-based authentication
    • Region Configuration: AWS region selection (currently defaults to us-east-1)
    • JSON Payloads: Automatic serialization of inputs to SageMaker format
    • Connection Validation: Test inference with ping request

    Configuration Example:

    
     ModelConfig config = ModelConfig.builder()
         .modelId("sagemaker-model")
         .endpointUrl("my-sagemaker-endpoint") // SageMaker endpoint name
         .authConfig(AuthConfig.builder()
             .apiKey("ACCESS_KEY:SECRET_KEY") // Optional static credentials
             .build())
         .build();
    
     SageMakerInferenceClient client = new SageMakerInferenceClient();
     client.initialize(config);
     

    Authentication Options:

    1. Static Credentials: Provide ACCESS_KEY:SECRET_KEY in authConfig.apiKey
    2. IAM Role: Omit credentials to use AWS IAM role (EC2, ECS, Lambda)
    3. Profile: Use AWS profile from ~/.aws/credentials
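A sketch of how these three options might map onto AWS SDK v2 credential providers. The SDK classes and factory methods below are real; the selection logic itself is illustrative, not the client's actual implementation:

```java
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;

public class CredentialsSelection {

    // Illustrative selection logic; not the client's actual code.
    static AwsCredentialsProvider resolve(String apiKey, String profileName) {
        if (apiKey != null && apiKey.contains(":")) {
            // Option 1: static credentials in "ACCESS_KEY:SECRET_KEY" form
            String[] parts = apiKey.split(":", 2);
            return StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(parts[0], parts[1]));
        }
        if (profileName != null) {
            // Option 3: named profile from ~/.aws/credentials
            return ProfileCredentialsProvider.create(profileName);
        }
        // Option 2: the default provider chain resolves IAM roles
        // (EC2, ECS, Lambda), environment variables, and the default profile.
        return DefaultCredentialsProvider.create();
    }
}
```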

    SageMaker Request Format:

     POST /endpoints/{endpoint-name}/invocations
     Content-Type: application/json
    
     {
       "feature1": value1,
       "feature2": value2,
       ...
     }
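The HTTP request above is what the SDK issues under the hood. With the AWS SDK for Java v2, the equivalent call looks roughly like this (endpoint name and payload are placeholders):

```java
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;

public class InvokeSketch {
    public static void main(String[] args) {
        try (SageMakerRuntimeClient runtime = SageMakerRuntimeClient.builder()
                .region(Region.US_EAST_1)
                .build()) {
            InvokeEndpointRequest request = InvokeEndpointRequest.builder()
                    .endpointName("my-sagemaker-endpoint") // placeholder endpoint name
                    .contentType("application/json")
                    .body(SdkBytes.fromUtf8String("{\"feature1\": 0.5, \"feature2\": 1.2}"))
                    .build();
            InvokeEndpointResponse response = runtime.invokeEndpoint(request);
            String json = response.body().asUtf8String(); // raw model output as JSON
            System.out.println(json);
        }
    }
}
```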
     

    AWS SDK Integration:

    Uses AWS SDK for Java v2 with automatic retry logic, request compression, and connection pooling. The SDK handles:

    • Request signing with AWS Signature Version 4
    • Automatic retry with exponential backoff
    • Connection management and pooling
    • Request/response logging (when configured)
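The retry and timeout behavior listed above is tunable through the SDK's client override configuration. A minimal sketch (the retry count and timeout values are illustrative, not the client's defaults):

```java
import java.time.Duration;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient;

public class OverrideConfigSketch {
    public static void main(String[] args) {
        SageMakerRuntimeClient runtime = SageMakerRuntimeClient.builder()
                .overrideConfiguration(ClientOverrideConfiguration.builder()
                        .retryPolicy(RetryPolicy.builder()
                                .numRetries(3) // illustrative value
                                .build())
                        .apiCallTimeout(Duration.ofSeconds(30)) // illustrative value
                        .build())
                .build();
        runtime.close();
    }
}
```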

    Cost Considerations:

    • SageMaker real-time endpoints are billed per instance-hour of endpoint uptime, plus data transfer
    • Consider batch inference to reduce cost per prediction
    • Use appropriate instance types for cost-performance balance

    Thread Safety:

    SageMakerRuntimeClient is thread-safe and can be shared across threads. The client uses connection pooling and automatic request retry.

    Since:
    1.0.0
    Author:
    Nestor Martourez, Sr Software and Data Streaming Engineer @ CodedStreams
    See Also:
    RemoteInferenceEngine, SageMakerRuntimeClient, SageMaker InvokeEndpoint API
    • Constructor Detail

      • SageMakerInferenceClient

        public SageMakerInferenceClient()
    • Method Detail

      • initialize

        public void initialize​(ModelConfig config)
                        throws InferenceException
        Initializes the SageMaker inference client with AWS configuration.

        Initialization process:

        1. Creates SageMakerRuntimeClient with configured region
        2. Sets up static credentials if provided in authConfig.apiKey
        3. Initializes ObjectMapper for JSON serialization
        4. Creates basic ModelMetadata from configuration

        Region Configuration:

        The region currently defaults to us-east-1. Support for configuring the region via model options can be added if needed.
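If region configuration were wired through model options, it might look like the sketch below. The `region` options key and the `getOptions()` accessor on ModelConfig are assumptions for illustration:

```java
import java.util.Map;
import software.amazon.awssdk.regions.Region;

public class RegionResolution {
    // Hypothetical: assumes ModelConfig exposes an options map.
    static Region resolve(Map<String, Object> options) {
        String regionName = (String) options.getOrDefault("region", "us-east-1");
        return Region.of(regionName);
    }
}
```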

        Credential Parsing:

        If authConfig.apiKey is provided, it should be in the format "ACCESS_KEY:SECRET_KEY". The first colon separates the access key from the secret key.
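A minimal sketch of that parsing as a hypothetical helper (not the class's actual code). Splitting at the first colon is sufficient, since AWS access key IDs do not contain colons:

```java
public class ApiKeyCredentials {

    /**
     * Parses "ACCESS_KEY:SECRET_KEY" into its two parts.
     * Hypothetical helper for illustration only.
     */
    public static String[] parse(String apiKey) {
        String[] parts = apiKey.split(":", 2); // split at first colon only
        if (parts.length != 2 || parts[0].isEmpty() || parts[1].isEmpty()) {
            throw new IllegalArgumentException(
                    "apiKey must be in ACCESS_KEY:SECRET_KEY format");
        }
        return parts;
    }
}
```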

        Specified by:
        initialize in interface InferenceEngine<Void>
        Overrides:
        initialize in class RemoteInferenceEngine
        Parameters:
        config - model configuration containing SageMaker endpoint name
        Throws:
        InferenceException - if client initialization fails or credentials are invalid
      • infer

        public InferenceResult infer​(Map<String,​Object> inputs)
                              throws InferenceException
        Invokes SageMaker endpoint for inference.

        Request flow:

        1. Serialize inputs to JSON using ObjectMapper
        2. Create InvokeEndpointRequest with endpoint name and JSON body
        3. Execute request via SageMakerRuntimeClient.invokeEndpoint(software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest)
        4. Parse response JSON back to Map
        5. Return InferenceResult with timing information

        SageMaker Response Format:

        SageMaker returns the raw model output as JSON. The structure depends on the model's output configuration. Common formats include:

        • Single value: {"prediction": 0.75}
        • Array: {"predictions": [0.1, 0.2, 0.7]}
        • Multiple outputs: {"class": "cat", "confidence": 0.92}
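Step 4 of the request flow, parsing the response JSON back to a Map with Jackson (the same ObjectMapper family the client uses for serialization), could look like this sketch:

```java
import java.util.Map;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

public class ResponseParsing {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        // Example response body in the "multiple outputs" format above
        String json = "{\"class\": \"cat\", \"confidence\": 0.92}";
        Map<String, Object> outputs =
                mapper.readValue(json, new TypeReference<Map<String, Object>>() {});
        System.out.println(outputs.get("class"));
    }
}
```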
        Specified by:
        infer in interface InferenceEngine<Void>
        Specified by:
        infer in class RemoteInferenceEngine
        Parameters:
        inputs - map of input names to values (must be JSON-serializable)
        Returns:
        inference result containing SageMaker model outputs
        Throws:
        InferenceException - if SageMaker API call fails or response parsing fails
      • validateConnection

        public boolean validateConnection()
        Validates connection to SageMaker endpoint by sending a test inference.

        Sends a simple "ping" inference request to verify:

        • Endpoint exists and is accessible
        • Authentication works
        • Endpoint responds to inference requests

        Note: This may incur SageMaker charges for the test inference. Consider implementing a lighter validation if cost is a concern.
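A typical fail-fast usage pattern for this method, using the documented API (`config` as built in the class-level configuration example):

```java
public class ValidationSketch {
    static void startUp(ModelConfig config) throws InferenceException {
        SageMakerInferenceClient client = new SageMakerInferenceClient();
        client.initialize(config);
        if (!client.validateConnection()) {
            // Fail fast before routing traffic; note the test inference may incur charges.
            throw new IllegalStateException("SageMaker endpoint validation failed");
        }
    }
}
```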

        Specified by:
        validateConnection in class RemoteInferenceEngine
        Returns:
        true if test inference succeeds
      • getMetadata

        public ModelMetadata getMetadata()
        Gets metadata about the SageMaker model.
        Returns:
        model metadata extracted during initialization
      • getModelConfig

        public ModelConfig getModelConfig()
        Gets the model configuration.
        Returns:
        the model configuration from inference config