Input/output control with TensorFlow and AWS SageMaker

18/02/20, by Dario Ferrer

TensorFlow inference input

One area of TensorFlow use that causes great confusion is how to pass input to a model for inference. A good approach is to give the model an input_signature: this provides some input validation and also allows flexible logic at call time, which can eliminate a separate Transform step in a typical ETL pipeline. The TensorFlow documentation describes input_signature as follows:

input_signature: A possibly nested sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. If None, a separate function is instantiated for each inferred input signature. If input_signature is specified, every input to func must be a Tensor, and func cannot accept **kwargs.

This is great for the many use cases where we can simply ‘load and go’, but in the real world the data is more likely not to be in quite the right shape. It would be excellent if we could perform further data manipulation right before, after or even during TensorFlow predictions at serving time.

Where is the rest of TensorFlow?

The aim of this post is simply to clarify TensorFlow input/output manipulation. It continues where our previous post left off; there we discuss the setup of a SageMaker and TensorFlow deployment in depth.

SageMaker input/output manipulation

AWS SageMaker serving containers (used in endpoints and batch jobs) will look for a file called in the TensorFlow exported model. If it exists, it will look for 3 specific function names inside:

  • input_handler: If defined, any input passed to the model will be passed through this function, and whatever the function returns will be passed to the TensorFlow model.

  • output_handler: If defined, the output of the TensorFlow prediction will be passed through this function and the function’s output will be sent back to the client as the TensorFlow prediction result.

  • handler: This is a special function intended for more generic I/O than the previous two. Here you are responsible for handling the complete request: from the client, through a TensorFlow model, and back again. Typical use cases are calling a second or third model, or converting between different data encodings (REST vs. gRPC).

This final method is the most powerful and is the approach we have taken for our client.
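To make the first two hooks more concrete, here is a minimal sketch of an input_handler / output_handler pair. It assumes the calling convention used by SageMaker's TensorFlow Serving container as we understand it: `data` is a readable byte stream, and `context` carries `request_content_type` and `accept_header`. Wrapping a bare list into `instances` is just an illustrative transformation, and the `SimpleNamespace` objects are stand-ins so the functions can be exercised outside the container.

```python
import io
import json
from types import SimpleNamespace


def input_handler(data, context):
    """Wrap a bare JSON list of records into a TensorFlow Serving predict request.

    Assumes `data` is a readable byte stream and `context` exposes
    `request_content_type`, as in SageMaker's TensorFlow Serving container.
    """
    if context.request_content_type != 'application/json':
        raise ValueError('unsupported content type {}'.format(context.request_content_type))
    records = json.loads(data.read().decode('utf-8'))
    # The TensorFlow Serving REST predict API expects the records under "instances"
    return json.dumps({'signature_name': 'predict', 'instances': records})


def output_handler(response, context):
    """Hand TensorFlow Serving's body back unchanged, with the client's requested content type."""
    return response.content, context.accept_header


# Stand-in objects for a local check; inside the container SageMaker builds the real ones
ctx = SimpleNamespace(request_content_type='application/json', accept_header='application/json')
request_body = input_handler(io.BytesIO(b'[{"quotation_id": 100007527745}]'), ctx)
reply = output_handler(SimpleNamespace(content=b'{"predictions": [1.0]}'), ctx)
```

The handler function replaces both hooks with a single entry point, which is what makes the ID matching below possible.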

Match the input IDs with the inferences

I’m going to show a very simple example of how to use the handler function in the file, and exactly where to place that file in our exported model tarball.

The case is fairly simple; we wish to pass deserialised data in the form of a list of JSON objects to the live SageMaker endpoints. This is our standard way of calling the TensorFlow model for predictions from our application; the input data looks like this:

data1 = {
    'quotation_id': 100007527745,
    'year': '2019',
    'month': '01',
    'stage_id': 1,
    'company_id': 2,
    'product_modality_id': 2,
    'cover_3': 1,
}
data2 = {
    'quotation_id': 100006042938,
    'year': '2018',
    'stage_id': 1,
    'company_id': 3,
    'product_modality_id': 7,
    'cover_3': 0,
}
data_list = [data1, data2]
input_data = {'signature_name': 'predict', 'instances': data_list}

The input_data dictionary built in the last line of the snippet is what we send to our live SageMaker endpoint to get those 2 predictions.

As you can see, we have a field called quotation_id, a unique key by which we can match each prediction to its corresponding input. When submitting the input to TensorFlow, we collect the quotation_ids into a list. TensorFlow returns the predictions as a list in the same order as the inputs, so we can label each prediction with the quotation_id at the same position in our saved list. This happens entirely in memory, without the need for any saved state, because the handler function processes both the input and the output in the same call. Here’s the code:

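import requests


def handler(data, context):
    """Handle request.

    Args:
        data (obj): the request data
        context (Context): an object containing request and configuration details

    Returns:
        (bytes, string): data to return to client, (optional) response content type
    """
    processed_input, list_of_ids = _process_input(data, context)
    # Forward the processed input to the TensorFlow Serving REST API inside the container
    response = requests.post(context.rest_uri, data=processed_input)
    output = _process_output(response, context)
    # Simplified: pair the output with the quotation IDs by position
    return zip(output, list_of_ids)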

This is a simplified version with some business logic removed, but the concept is the same - we process the input, pass it to TensorFlow, process the output, then send the endpoint response back to the client.
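To illustrate why no saved state is needed, here is a stand-alone sketch of the same process, predict and label flow. It is not our production handler: `handle_request` and `fake_predict` are hypothetical names, and the HTTP call to TensorFlow Serving is replaced by an injected `predict` callable (assumed to return one prediction per instance, in input order) so the logic can run anywhere.

```python
import json


def handle_request(body, predict):
    """Process input, call the model and label the output, all within one call.

    `body` is a JSON string with an "instances" list; `predict` is a hypothetical
    stand-in for the HTTP call to TensorFlow Serving that takes a request body
    string and returns a list of predictions in input order.
    """
    request = json.loads(body)
    instances = request['instances']
    # Remember the IDs in a local variable; they never need to be stored anywhere else
    list_of_ids = [instance['quotation_id'] for instance in instances]
    predictions = predict(json.dumps({
        'signature_name': request.get('signature_name', 'predict'),
        'instances': instances,
    }))
    # Predictions come back in the same order as the instances, so position is the join key
    return [[quotation_id, prediction] for quotation_id, prediction in zip(list_of_ids, predictions)]


def fake_predict(request_body):
    """A fake model that 'predicts' twice the company_id, just to exercise the flow."""
    return [2.0 * item['company_id'] for item in json.loads(request_body)['instances']]


payload = json.dumps({'signature_name': 'predict',
                      'instances': [{'quotation_id': 100007527745, 'company_id': 2},
                                    {'quotation_id': 100006042938, 'company_id': 3}]})
result = handle_request(payload, fake_predict)
```

Because the IDs live in a local variable for the duration of the call, pairing reduces to a positional zip.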

CSV deserialisation

In our use case we not only needed to match input fields with output fields, we also needed to accept the data in multiple formats. Apart from the JSON format above, we use a bare CSV format for batch inference jobs. Our CSV is a bit special: it uses ; as the field separator and it is encoded in latin-1. Note that you can also deal with different data formats in the input_signature, but as we needed to manipulate the data further, we made it part of our file.
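The decoding and splitting step on its own might look like the sketch below; `parse_semicolon_csv` is a hypothetical helper and the sample bytes are invented. Decoding as latin-1 matters here because a byte such as 0xE9 (é in latin-1) is not valid UTF-8 on its own.

```python
import csv
import io


def parse_semicolon_csv(raw_bytes):
    """Decode latin-1 bytes and split them into rows on ';', skipping blank lines."""
    text = raw_bytes.decode('latin-1')
    reader = csv.reader(io.StringIO(text, newline=''), delimiter=';')
    # csv.reader yields an empty list for a blank line, so filter those out
    return [row for row in reader if row]


# 0xE9 is 'é' in latin-1; the blank line in the middle is dropped
raw = b'100007527745;2019;01;Jos\xe9\n\n100006042938;2018;02;;\n'
rows = parse_semicolon_csv(raw)
```

Empty fields (the trailing ;; in the second row) survive as empty strings, which is why a type-setting step with defaults is still needed afterwards.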

Let’s have a look at the input processing function:

import csv
import json


def set_type(value, default_val):
    """Just a naive type setting function: cast value to the type of default_val."""
    if type(default_val) is float:
        if value == '':
            return float(0)
        return float(value)
    elif type(default_val) is int:
        if value == '':
            return int(0)
        return int(value)
    elif type(default_val) is str:
        return str(value)
    # Any other type: leave the value untouched
    return value


def _process_input(data, context):
    """Pre-process request input before it is sent to TensorFlow Serving REST API

    Args:
        data (obj): the request data, in format of dict or string
        context (Context): an object containing request and configuration details

    Returns:
        (tuple): the request body for TensorFlow Serving, and the list of quotation IDs
    """
    ## This is a dictionary of the columns present on the CSV data mapped to each default value
    dataset_column_names_and_defaults = {'quotation_id': 0, 'year': '', 'month': '', 'stage_id': 0,
                                         'company_id': 0, 'product_modality_id': 0, 'cover_3': 0}
    ## Define the list of columns present in the CSV dataset, and their defaults in the same order
    dataset_column_names = list(dataset_column_names_and_defaults.keys())
    dataset_column_defaults = list(dataset_column_names_and_defaults.values())

    if context.request_content_type == 'application/json':
        # pass through json (assumes it's correctly formed)
        d = data.read().decode('utf-8')
        return (d if len(d) else ''), []

    if context.request_content_type == 'text/csv':
        # very simple csv handler
        data = data.read().decode('latin-1')
        data_list = list(csv.reader(data.split('\n'), delimiter=';'))
        list_of_instances = [
            {column_name: set_type(value, default_val)
             for column_name, value, default_val in zip(dataset_column_names, item, dataset_column_defaults)}
            for item in data_list]

        # Delete empty instances (e.g. the one produced by the trailing newline)
        list_of_instances = [instance for instance in list_of_instances if instance]
        list_of_ids = [item['quotation_id'] for item in list_of_instances]
        # Remove price
        for instance in list_of_instances:
            instance.pop('price', None)

        body = json.dumps({
            'signature_name': 'predict',
            'instances': list_of_instances
        })

        return body, list_of_ids
    raise ValueError('{{"error": "unsupported content type {}"}}'.format(
        context.request_content_type or "unknown"))

As you can see, we accept two types of input data and process each of them differently. In both cases we return the data in the format required by the model’s input_signature, together with list_of_ids as a list.
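One possible variation on the CSV branch, sketched with hypothetical names (`COLUMNS`, `convert`, `rows_to_request`), skips blank rows before building instances, so there is no empty instance to delete afterwards. It also falls back to each column's default for empty fields, which differs slightly from set_type: set_type returns zero for empty numeric fields whatever the default is.

```python
import json

# Hypothetical column spec mirroring the example: column name -> default value,
# whose type also decides how the raw string is converted
COLUMNS = {'quotation_id': 0, 'year': '', 'month': '', 'stage_id': 0,
           'company_id': 0, 'product_modality_id': 0, 'cover_3': 0}


def convert(value, default):
    """Cast a raw CSV string to the type of its default; empty strings become the default."""
    if value == '':
        return default
    return type(default)(value)


def rows_to_request(rows):
    """Build the TensorFlow Serving request body and the ordered list of IDs."""
    instances = [
        {name: convert(value, default)
         for (name, default), value in zip(COLUMNS.items(), row)}
        for row in rows if row  # skip blank rows instead of producing empty instances
    ]
    list_of_ids = [instance['quotation_id'] for instance in instances]
    body = json.dumps({'signature_name': 'predict', 'instances': instances})
    return body, list_of_ids


body, ids = rows_to_request([['100007527745', '2019', '01', '1', '2', '2', '1'], []])
```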

We pass the processed data to TensorFlow, get the list of inferences and then pass it to the output processing function. The output function is simple: it checks that TensorFlow responded with a clean HTTP 200 and raises an error if this is not the case:

def _process_output(data, context):
    """Post-process TensorFlow Serving output before it is returned to the client.

    Args:
        data (obj): the TensorFlow serving response
        context (Context): an object containing request and configuration details

    Returns:
        (bytes, string): data to return to client, response content type
    """
    if data.status_code != 200:
        raise ValueError(data.content.decode('utf-8'))

    response_content_type = context.accept_header
    prediction = data.content
    return prediction, response_content_type
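_process_output hands the raw response bytes back unchanged. Before any IDs can be attached to the predictions, those bytes need decoding. Here is a hedged sketch of that step, assuming TensorFlow Serving's REST convention of a JSON body with a `predictions` key on success and an `error` key on failure; `check_and_parse` and the stand-in responses are hypothetical.

```python
import json
from types import SimpleNamespace


def check_and_parse(response):
    """Raise on a non-200 TensorFlow Serving reply, otherwise return the predictions list."""
    if response.status_code != 200:
        # On failure the body carries TensorFlow Serving's error message
        raise ValueError(response.content.decode('utf-8'))
    # json.loads accepts bytes directly
    return json.loads(response.content)['predictions']


# Stand-ins for the requests.Response objects the handler would receive
ok = SimpleNamespace(status_code=200, content=b'{"predictions": [1198.35742, 1000.75159]}')
bad = SimpleNamespace(status_code=400, content=b'{"error": "bad input"}')
```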

Ordered, labeled results

Finally, we can get our predictions either from the live SageMaker endpoint or by creating a batch job and passing the input as CSV files in an S3 bucket. What we get is an ordered list of inferences with the quotation ID as an output field. The last line of our handler function (return zip(output, list_of_ids)) is responsible for gluing the quotation IDs to the predictions.
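The gluing step can be made a little more defensive. In this sketch, `label_predictions` is a hypothetical helper: it materialises the pairs as a list, because a bare zip object is a lazy iterator that json.dumps cannot serialise, and it checks lengths first, because zip silently stops at the shorter input.

```python
import json


def label_predictions(list_of_ids, predictions):
    """Pair each ID with its prediction and serialise the result for the client."""
    if len(list_of_ids) != len(predictions):
        # zip() would silently drop the unmatched items, hiding the mismatch
        raise ValueError('got {} predictions for {} IDs'.format(len(predictions), len(list_of_ids)))
    # Build a real list; json.dumps raises TypeError on a zip object
    labelled = [[quotation_id, prediction] for quotation_id, prediction in zip(list_of_ids, predictions)]
    return json.dumps({'predictions': labelled}), 'application/json'


body, content_type = label_predictions([100003822499, 100005706767], [1198.35742, 1000.75159])
```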

An example of the actual output looks like this:

{"predictions": [
    [100003822499, 1198.35742],
    [100005706767, 1000.75159],
    [100005489495, 489.44455],
    [100006569319, 367.685852],
    [100003898000, 508.068024],
    [100004221265, 474.993866]
]}

The quotation ID is the first element of each result pair, and the quotation itself, i.e. the TensorFlow prediction, is the second.

File location

TensorFlow is a complex piece of software and it can be difficult to get things working. One area where we initially stumbled was the handler script, so here is exactly where the file has to go:

  • The file must be called exactly
  • The file must be placed in the exported model tarball, in the code dir
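A quick way to check this layout before deploying is to list what sits under code/ in the tarball. The sketch uses the standard tarfile module; `files_under_code_dir` and `build_example_tarball` are hypothetical helpers, and `code/example_handler.py` is a placeholder, not the file name the container looks for.

```python
import io
import tarfile


def files_under_code_dir(tar_bytes):
    """Return the names of regular files stored under code/ in a gzipped model tarball."""
    found = []
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode='r:gz') as tar:
        for member in tar.getmembers():
            # Archives created with `tar -C dir .` prefix entries with "./"
            name = member.name[2:] if member.name.startswith('./') else member.name
            if member.isfile() and name.startswith('code/'):
                found.append(name)
    return found


def build_example_tarball():
    """Build a tiny in-memory tarball with a model version dir and a placeholder code/ file."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode='w:gz') as tar:
        for name, payload in [('1/saved_model.pb', b'placeholder'),
                              ('code/example_handler.py', b'# placeholder')]:
            info = tarfile.TarInfo(name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    return buffer.getvalue()
```

Against a real export you would read the bytes of the downloaded model.tar.gz instead of calling build_example_tarball().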

As we’re using SageMaker to train and export the model (with the SageMaker fit function in script mode), we copy the file to its final location from the entry-point script itself. First, we must add the file to the custom code location by adding it to the list of dependencies:

insurance_estimator = TensorFlow(entry_point='',
                                    'enabled': False

And then we add the following lines at the end of the entry-point script:

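    ## Copy the custom input/output handler file into the model's code directory
    ## (needs `import os` and `from shutil import copyfile` at the top of the script)
    code_dir = os.path.join(local_model_dir, 'code')
    os.makedirs(code_dir, exist_ok=True)  # create code/ if it does not exist yet
    copyfile('', os.path.join(code_dir, ''))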

And that’s it: the file is now part of the exported tarball, without the need to decompress the tarball, add the file and compress it again.

Get in touch

If you’re interested in what you’ve read and would like to talk to someone about how Machine Learning could give your business the edge, please get in touch!
