Serve TensorFlow Models in Java

TensorFlow is a popular machine learning framework from Google and a must-know tool for machine learning engineers. Although Python is the recommended language for building TensorFlow models, Google also offers a Java API for TensorFlow. Still, Python is the easiest language for building TensorFlow models, even for Java developers (learn Python, my friend). However, enterprise applications written in Java may need the intelligence provided by a trained TensorFlow model. In this article, you will learn how to load and use a simple TensorFlow model exported from Python.



Prerequisite

Step 1:
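Create a Python script named train_tf_model.py with the following content and run it using Python 3 to save the model in /tmp/tf_add_model. The script uses the TensorFlow 1.x API (tf.placeholder and tf.Session), so run it with a TensorFlow 1.x installation. If you are a Windows user, change /tmp/tf_add_model to a Windows-specific path, both here and in the Java code.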
#!/usr/bin/env python3
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.utils import build_tensor_info

x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')

# This is our model
add = tf.add(x, y, name='ans')

with tf.Session() as sess:
    # Pick out the model input and output
    x_tensor = sess.graph.get_tensor_by_name('x:0')
    y_tensor = sess.graph.get_tensor_by_name('y:0')
    ans_tensor = sess.graph.get_tensor_by_name('ans:0')

    x_info = build_tensor_info(x_tensor)
    y_info = build_tensor_info(y_tensor)
    ans_info = build_tensor_info(ans_tensor)

    # Create a signature definition for tfserving
    signature_definition = signature_def_utils.build_signature_def(
        inputs={'x': x_info, 'y': y_info},
        outputs={'ans': ans_info},
        method_name=signature_constants.PREDICT_METHOD_NAME)

    builder = saved_model_builder.SavedModelBuilder('/tmp/tf_add_model')

    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_definition
        })

    # Save the model so we can serve it with a model server :)
    builder.save()
The TensorFlow model built in the above script

In this code, we create a TensorFlow model that takes two inputs, x and y, and produces an output, ans. The model is saved to the /tmp/tf_add_model path. Note that x, y, and ans are all of type int32, so we need to map them to Java int values later.
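Because the placeholders are int32, the Java side has to supply int values. In the TensorFlow Java API, a boxed Integer produces an INT32 tensor, while a Long produces an INT64 tensor that the int32 placeholders reject. Enterprise code often carries numbers as long, so here is a minimal sketch of narrowing them safely before creating the tensors. Int32Inputs and toInt32 are hypothetical names used only for illustration; they are not TensorFlow classes.

```java
// Hypothetical helper (not part of the TensorFlow API): narrows long values to
// int before they are used to create tensors for the int32 inputs "x" and "y".
final class Int32Inputs {

    private Int32Inputs() {
    }

    static int toInt32(String inputName, long value) {
        // Reject values outside the 32-bit range instead of silently truncating them.
        if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Input '" + inputName + "' does not fit in int32: " + value);
        }
        return (int) value;
    }

    public static void main(String[] args) {
        long x = 10L; // e.g. values loaded from a BIGINT database column
        long y = 20L;
        int xInt = toInt32("x", x);
        int yInt = toInt32("y", y);
        System.out.println("x=" + xInt + ", y=" + yInt); // prints x=10, y=20
    }
}
```

The narrowed int values can then be passed to Tensor.create, which infers INT32 from the boxed Integer.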

Running this script with Python 3 creates the following files in the /tmp directory.
tmp
└── tf_add_model
    ├── saved_model.pb
    └── variables
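Before loading the model, it can help to confirm that the export directory exists where the Java code expects it. The following is a minimal pre-flight check in plain Java, assuming you want to fail early with a readable message. SavedModelCheck and looksLikeSavedModel are hypothetical names, not part of the TensorFlow API; the check relies only on the SavedModel convention that the graph is stored in saved_model.pb (or saved_model.pbtxt).

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical pre-flight check (not part of the TensorFlow API): confirms that a
// directory looks like a SavedModel export before SavedModelBundle.load is called.
final class SavedModelCheck {

    private SavedModelCheck() {
    }

    static boolean looksLikeSavedModel(Path exportDir) {
        // A SavedModel directory holds its graph in saved_model.pb
        // (or, less commonly, the text format saved_model.pbtxt).
        return Files.isDirectory(exportDir)
                && (Files.isRegularFile(exportDir.resolve("saved_model.pb"))
                    || Files.isRegularFile(exportDir.resolve("saved_model.pbtxt")));
    }

    public static void main(String[] args) {
        // Same default path as the Python script; pass another path as the first argument.
        Path exportDir = Paths.get(args.length > 0 ? args[0] : "/tmp/tf_add_model");
        if (looksLikeSavedModel(exportDir)) {
            System.out.println("SavedModel found at " + exportDir);
        } else {
            System.out.println("No SavedModel at " + exportDir + "; run train_tf_model.py first");
        }
    }
}
```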


Step 2:
Create a new Apache Maven project in your favorite IDE with the group ID com.javahelps.tensorflow and the artifact ID model-server.


Step 3:
Add the TensorFlow dependency to the pom.xml file as shown below:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.javahelps.tensorflow</groupId>
    <artifactId>model-server</artifactId>
    <version>1.0-SNAPSHOT</version>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>11</source>
                    <target>11</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <dependencies>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.13.1</version>
        </dependency>
    </dependencies>

</project>

Step 4:
Create a new package, com.javahelps.tensorflow.modelserver, in the src/main/java folder.

Step 5:
Create a new class ModelServer.java with the following code:
package com.javahelps.tensorflow.modelserver;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;

import java.util.List;

public class ModelServer {

    public static void main(String[] args) {
        // "serve" is the value of tag_constants.SERVING used in the Python script
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load("/tmp/tf_add_model", "serve")) {

            // Tensors hold native memory, so they are closed as well.
            // Java int values create INT32 tensors, matching the placeholders.
            try (Session session = savedModelBundle.session();
                 Tensor<?> x = Tensor.create(10);
                 Tensor<?> y = Tensor.create(20)) {

                List<Tensor<?>> tensors = session.runner()
                        .feed("x", x)
                        .feed("y", y)
                        .fetch("ans")
                        .run();

                // The fetched output tensor must be closed too.
                try (Tensor<?> ans = tensors.get(0)) {
                    System.out.println("Answer is: " + ans.intValue());
                }
            }

        } catch (TensorFlowException ex) {
            ex.printStackTrace();
        }
    }
}
In this code, SavedModelBundle and Session implement java.lang.AutoCloseable, so they are created in try-with-resources blocks. The SavedModelBundle object is loaded from the TensorFlow model saved in the /tmp/tf_add_model folder. The session runner is fed with the x and y inputs. Note that these inputs must be int values because we defined them as int32 in the Python script. Similarly, the output ans is read as an int value. Mismatched types cause a java.lang.IllegalArgumentException at runtime.
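The nested blocks rely on two properties of try-with-resources: resources are closed in the reverse order in which they were opened, and they are closed even when an exception such as IllegalArgumentException escapes the block. The following TensorFlow-free sketch makes that visible. TrackedResource is a hypothetical stand-in for SavedModelBundle and Session that only records when it is opened and closed.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for SavedModelBundle/Session: records open and close
// events so the try-with-resources behaviour can be observed.
final class TrackedResource implements AutoCloseable {

    private final String name;
    private final List<String> log;

    TrackedResource(String name, List<String> log) {
        this.name = name;
        this.log = log;
        log.add("open " + name);
    }

    @Override
    public void close() {
        log.add("close " + name);
    }

    // Mirrors the nesting in ModelServer: the session is opened inside the bundle's block.
    static List<String> runNested(boolean fail) {
        List<String> log = new ArrayList<>();
        try (TrackedResource bundle = new TrackedResource("bundle", log)) {
            try (TrackedResource session = new TrackedResource("session", log)) {
                if (fail) {
                    throw new IllegalArgumentException("type mismatch");
                }
                log.add("run");
            }
        } catch (IllegalArgumentException ex) {
            // Runs only after both resources have already been closed.
            log.add("caught " + ex.getMessage());
        }
        return log;
    }

    public static void main(String[] args) {
        // prints [open bundle, open session, run, close session, close bundle]
        System.out.println(runNested(false));
        // prints [open bundle, open session, close session, close bundle, caught type mismatch]
        System.out.println(runNested(true));
    }
}
```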

Running this code prints "Answer is: 30" because we feed 10 and 20 to the TensorFlow model, which returns the sum of its inputs. Although this example is purposely kept simple, you can use the same approach with a more complex machine learning model that takes multiple vectors as input and produces one or more vectors as output (vectors map to arrays in Java).
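In the 1.x Java API, Tensor.create(Object) infers a tensor's shape from the nesting of a Java array, so a batch of vectors is passed as an int[][] of shape [batch, length], and an output can be copied back into an array of the same shape with Tensor.copyTo. Data often arrives as a flat buffer, though. The following is a minimal sketch of that reshaping step in plain Java, using a hypothetical BatchShapes.toMatrix helper; the scalar model in this article does not need it.

```java
import java.util.Arrays;

// Hypothetical helper (not part of the TensorFlow API): turns a flat buffer into
// the nested int[][] that a [rows, cols] int32 input would be created from.
final class BatchShapes {

    private BatchShapes() {
    }

    static int[][] toMatrix(int[] flat, int rows, int cols) {
        if (rows < 0 || cols < 0 || (long) rows * cols != flat.length) {
            throw new IllegalArgumentException(
                    "Cannot reshape " + flat.length + " values into [" + rows + ", " + cols + "]");
        }
        int[][] matrix = new int[rows][cols];
        for (int r = 0; r < rows; r++) {
            // Copy one row (one vector of the batch) at a time.
            System.arraycopy(flat, r * cols, matrix[r], 0, cols);
        }
        return matrix;
    }

    public static void main(String[] args) {
        int[] flat = {1, 2, 3, 4, 5, 6};
        int[][] batch = toMatrix(flat, 2, 3);
        System.out.println(Arrays.deepToString(batch)); // prints [[1, 2, 3], [4, 5, 6]]
    }
}
```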

I hope this short article is clear enough to give you the idea. If you have any questions, feel free to comment below.

Find the project on GitHub (the Python script and the saved model are also available there).