
Building Predictive APIs with TensorFlow and Spring Boot

1. Why Combine AI/ML with Spring Boot?

Modern applications increasingly need smart capabilities – from recommendation engines to fraud detection. While Python dominates ML development, Java teams can leverage:

  • TensorFlow Java for model inference
  • Spring Boot for scalable API delivery
  • DJL (Deep Java Library) as an alternative framework

This guide walks through serving a trained ML model via REST API with zero Python dependencies at runtime.

2. Architecture Overview

[Python Environment] -- Trains Model --> SavedModel.pb
                          |
                          v
[Java Service] <-- Loads Model --> [Spring Boot REST API]
                          |
                          v
[Client Apps] <-- Gets Predictions

Key components:

  1. TensorFlow SavedModel (exported from Python)
  2. Spring Boot web layer
  3. TensorFlow Java API for inference

Step 1: Train and Export Model (Python)

# train.py
import numpy as np
import tensorflow as tf

# Dummy training data: 3 input features, 1 regression target
X_train = np.random.rand(100, 3).astype("float32")
y_train = np.random.rand(100, 1).astype("float32")

# Sample neural network; input_shape pins down the serving signature
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=10)

# Export for Java (SavedModel format)
tf.saved_model.save(model, "saved_model")

This creates a saved_model/ directory containing:

  • saved_model.pb (architecture)
  • variables/ (trained weights)

Step 2: Spring Boot Integration

Dependencies (pom.xml)

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <version>0.4.1</version>
</dependency>

Load Model in Java

import javax.annotation.PostConstruct; // jakarta.annotation on Spring Boot 3+

import org.springframework.stereotype.Service;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

@Service
public class Predictor {
    private SavedModelBundle model;

    @PostConstruct
    public void init() {
        // "serve" is the tag the SavedModel was exported with
        this.model = SavedModelBundle.load(
            "src/main/resources/saved_model",
            "serve"
        );
    }

    public float predict(float[] input) {
        // Wrap the features as a 1xN batch. The feed/fetch names below are
        // typical for a Keras SavedModel; verify yours with
        // `saved_model_cli show --dir saved_model --all`.
        try (TFloat32 inputTensor =
                 TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] { input }));
             TFloat32 result = (TFloat32) model.session()
                .runner()
                .feed("serving_default_dense_input", inputTensor)
                .fetch("StatefulPartitionedCall")
                .run()
                .get(0)) {

            return result.getFloat(0, 0);
        }
    }
}

Step 3: Expose as REST API

@RestController
@RequestMapping("/api/predict")
public class PredictionController {
     
    @Autowired
    private Predictor predictor;
     
    @PostMapping
    public PredictionResponse predict(@RequestBody PredictionRequest request) {
        float result = predictor.predict(request.getFeatures());
        return new PredictionResponse(result);
    }
}
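The controller relies on two small payload classes that the article leaves implicit. A minimal sketch (the field names are assumptions inferred from the handler's `getFeatures()` call and the JSON payload of the sample request; both classes are shown in one listing for brevity, but each would be its own public class in a real project):

```java
// Request body: Jackson binds {"features": [...]} via the setter
class PredictionRequest {
    private float[] features;

    public float[] getFeatures() { return features; }
    public void setFeatures(float[] features) { this.features = features; }
}

// Response body: serialized by Jackson as {"prediction": ...}
class PredictionResponse {
    private final float prediction;

    public PredictionResponse(float prediction) { this.prediction = prediction; }
    public float getPrediction() { return prediction; }
}
```

With these in place, Spring's default Jackson converter handles the JSON (de)serialization with no extra configuration.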

Sample request:

curl -X POST http://localhost:8080/api/predict \
  -H "Content-Type: application/json" \
  -d '{"features": [0.1, 0.5, 0.3]}'

3. Performance Optimization Tips

  1. Batching Predictions
    Process multiple inputs in one session run:

float[][] batchInputs = ...;
TFloat32 batchTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(batchInputs));

2. GPU Acceleration
Add CUDA dependencies for NVIDIA GPUs:

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform-gpu</artifactId>
    <version>0.4.1</version>
</dependency>

3. Model Warmup
Initialize model at startup to avoid first-call latency:

@Bean
public CommandLineRunner warmup(Predictor predictor) {
    // inputSize = the number of features your model expects
    return args -> predictor.predict(new float[inputSize]);
}

4. Alternative: DJL (Deep Java Library)

For more Java-native ML workflows:

// Load a (PyTorch) model directly in Java
Model model = Model.newInstance("linear");
model.load(Paths.get("model.pt"));

try (NDManager manager = NDManager.newBaseManager()) {
    NDArray input = manager.create(new float[]{...});
    // A Translator converts between NDArrays and your input/output types;
    // NoopTranslator passes NDLists through unchanged
    Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
    NDList result = predictor.predict(new NDList(input));
}

Advantages:

  • Unified API for TensorFlow/PyTorch/MXNet
  • No SWIG/JNI overhead
  • Built-in image preprocessing

5. Conclusion

Key takeaways:
✅ Serve TensorFlow models without Python in production
✅ Achieve single-digit-millisecond inference latency for small models (after warmup, on typical server hardware)
✅ Scale horizontally like any Spring Boot service

Next Steps:

  1. Try the TensorFlow Java examples
  2. Explore DJL’s Spring Boot starter
  3. Monitor performance with Micrometer metrics
Eleftheria Drosopoulou

Eleftheria is an experienced Business Analyst with a robust background in the computer software industry. Proficient in Computer Software Training, Digital Marketing, HTML Scripting, and Microsoft Office, she brings a wealth of technical skills to the table. Additionally, she has a love for writing articles on various tech subjects, showcasing a talent for translating complex concepts into accessible content.