Building Predictive APIs with TensorFlow and Spring Boot
1. Why Combine AI/ML with Spring Boot?
Modern applications increasingly need smart capabilities – from recommendation engines to fraud detection. While Python dominates ML development, Java teams can leverage:
- TensorFlow Java for model inference
- Spring Boot for scalable API delivery
- DJL (Deep Java Library) as an alternative framework
This guide walks through serving a trained ML model via REST API with zero Python dependencies.
2. Architecture Overview
```
[Python Environment] -- Trains Model --> SavedModel.pb
        |
        v
[Java Service] <-- Loads Model --> [Spring Boot REST API]
        |
        v
[Client Apps] <-- Gets Predictions
```
Key components:
- TensorFlow SavedModel (exported from Python)
- Spring Boot web layer
- TensorFlow Java API for inference
Step 1: Train and Export Model (Python)
```python
# train.py
import tensorflow as tf

# Sample neural network
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=10)

# Export for Java
tf.saved_model.save(model, "saved_model")
```
This creates a `saved_model/` directory containing:
- `saved_model.pb` (architecture)
- `variables/` (trained weights)
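Before wiring the model into Spring, it can help to sanity-check that the export produced the layout above, since `SavedModelBundle.load()` fails with a less obvious error when files are missing. A minimal sketch (the class and method names are illustrative, not part of TensorFlow Java):

```java
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical pre-flight check: confirm the export has the expected layout
// (saved_model.pb plus a variables/ directory) before loading it.
class SavedModelCheck {
    static boolean looksLikeSavedModel(Path dir) {
        return Files.isRegularFile(dir.resolve("saved_model.pb"))
            && Files.isDirectory(dir.resolve("variables"));
    }
}
```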
Step 2: Spring Boot Integration
Dependencies (`pom.xml`)
```xml
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <version>0.4.1</version>
</dependency>
```
Load Model in Java
```java
import javax.annotation.PostConstruct;

import org.springframework.stereotype.Component;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

@Component
public class Predictor {

    private SavedModelBundle model;

    @PostConstruct
    public void init() {
        this.model = SavedModelBundle.load("src/main/resources/saved_model", "serve");
    }

    public float predict(float[] input) {
        // Wrap the features in a rank-2 tensor (a batch of one).
        // Input/output names depend on the exported signature; verify them
        // with `saved_model_cli show --dir saved_model --all` if these differ.
        try (TFloat32 inputTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] { input }));
             TFloat32 result = (TFloat32) model.session()
                     .runner()
                     .feed("dense_input", inputTensor)
                     .fetch("dense_1")
                     .run()
                     .get(0)) {
            return result.getFloat(0, 0);
        }
    }
}
```
Step 3: Expose as REST API
```java
@RestController
@RequestMapping("/api/predict")
public class PredictionController {

    @Autowired
    private Predictor predictor;

    @PostMapping
    public PredictionResponse predict(@RequestBody PredictionRequest request) {
        float result = predictor.predict(request.getFeatures());
        return new PredictionResponse(result);
    }
}
```
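The controller references two DTOs that mirror the JSON payloads. A minimal sketch, assuming a single `features` array in and a single `prediction` value out (the field names are illustrative, not prescribed by Spring or TensorFlow):

```java
// Request body: {"features": [0.1, 0.5, 0.3]}
class PredictionRequest {
    private float[] features;

    public float[] getFeatures() { return features; }
    public void setFeatures(float[] features) { this.features = features; }
}

// Response body: {"prediction": 1.5}
class PredictionResponse {
    private final float prediction;

    public PredictionResponse(float prediction) { this.prediction = prediction; }
    public float getPrediction() { return prediction; }
}
```

Jackson (bundled with `spring-boot-starter-web`) binds these automatically via the getters and setters.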
Sample request:
```shell
curl -X POST http://localhost:8080/api/predict \
  -H "Content-Type: application/json" \
  -d '{"features": [0.1, 0.5, 0.3]}'
```
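On the Java client side, the same JSON body can be assembled before posting it with any HTTP client. A dependency-free sketch (the class and method names are hypothetical; a real client would typically serialize with Jackson instead):

```java
// Hypothetical helper that builds the request body sent to /api/predict.
class PredictionClient {
    static String buildPayload(float[] features) {
        StringBuilder sb = new StringBuilder("{\"features\": [");
        for (int i = 0; i < features.length; i++) {
            if (i > 0) sb.append(", ");
            sb.append(features[i]);
        }
        return sb.append("]}").toString();
    }
}
```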
3. Performance Optimization Tips
1. Batching Predictions
Process multiple inputs in one session run:
```java
float[][] batchInputs = ...;
TFloat32 batchTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(batchInputs));
```
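Upstream of the tensor, the per-request feature vectors have to be collected into one rank-2 array. A small illustrative helper (the `toBatch` name is not from any library):

```java
import java.util.List;

// Sketch: gather individual feature vectors into one batch array so a
// single session run can score them all.
class Batcher {
    static float[][] toBatch(List<float[]> inputs) {
        float[][] batch = new float[inputs.size()][];
        for (int i = 0; i < inputs.size(); i++) {
            batch[i] = inputs.get(i).clone(); // defensive copy per row
        }
        return batch;
    }
}
```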
2. GPU Acceleration
Add CUDA dependencies for NVIDIA GPUs:
```xml
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform-gpu</artifactId>
    <version>0.4.1</version>
</dependency>
```
3. Model Warmup
Initialize model at startup to avoid first-call latency:
```java
@Bean
public CommandLineRunner warmup(Predictor predictor) {
    return args -> predictor.predict(new float[inputSize]);
}
```
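To see why this matters, here is a stand-in predictor that models the one-time first-call cost; the names are illustrative and nothing here is TensorFlow API:

```java
// Stand-in predictor: the first call pays a one-time setup cost
// (graph loading, kernel selection); later calls do not. Triggering
// that first call at startup keeps it out of real request latency.
class LazyPredictor {
    private boolean warm = false;

    float predict(float[] input) {
        if (!warm) {
            warm = true; // stands in for expensive first-call initialization
        }
        return 0f;
    }

    boolean isWarm() { return warm; }
}
```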
4. Alternative: DJL (Deep Java Library)
For more Java-native ML workflows:
```java
import java.nio.file.Paths;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.NoopTranslator;

// Load a model directly in Java
try (Model model = Model.newInstance("linear");
     NDManager manager = NDManager.newBaseManager()) {
    model.load(Paths.get("model.pt"));

    try (Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
        NDArray input = manager.create(new float[]{...});
        NDList result = predictor.predict(new NDList(input));
    }
}
```
Advantages:
- Unified API for TensorFlow/PyTorch/MXNet
- No SWIG/JNI overhead
- Built-in image preprocessing
5. Conclusion
Key takeaways:
✅ Serve TensorFlow models without Python in production
✅ Achieve sub-10ms latency per prediction for small models
✅ Scale horizontally like any Spring Boot service
Next Steps:
- Try the TensorFlow Java examples
- Explore DJL’s Spring Boot starter
- Monitor performance with Micrometer metrics