onnxruntime/onnxruntime/test/python/onnxruntime_test_python_keras.py
Tianlei Wu 403f99cd77
Use yapf to format python (#3276)
Update ReformatSourcePython.bat to use YAPF to format python code, and add onnxruntime\test directory to be formatted.

Add onnxruntime\.style.yapf for configuration. The style is based on Google's, except that the maximum column width is 120.

Format python scripts using ReformatSourcePython.bat.
2020-03-20 14:34:10 -07:00

81 lines
2.5 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -*- coding: UTF-8 -*-
# Taken from https://github.com/onnx/onnxmltools/blob/master/tests/end2end/test_custom_op.py.
import unittest
import os
import sys
import numpy as np
import onnxmltools
import onnxruntime as onnxrt
from keras import backend as K
from keras import Sequential
from keras.layers import Layer, Conv2D, MaxPooling2D
class ScaledTanh(Layer):
    """Keras layer that computes ``alpha * tanh(beta * x)``.

    The output has the same shape as the input (see ``compute_output_shape``).
    In this file the layer is exported to ONNX by the ``custom_activation``
    converter, which copies ``alpha`` and ``beta`` onto the ONNX node.

    Args:
        alpha: Multiplier applied to the tanh output.
        beta: Multiplier applied to the input before tanh.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, alpha=1.0, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        # Output scale, applied after tanh.
        self.alpha = alpha
        # Input scale, applied before tanh.
        self.beta = beta

    def build(self, input_shape):
        # The layer creates no weights of its own; only the base-class bookkeeping runs.
        super().build(input_shape)

    def call(self, x):
        pre_activation = self.beta * x
        return self.alpha * K.tanh(pre_activation)

    def compute_output_shape(self, input_shape):
        # Shape-preserving: the output shape is reported as the input shape.
        return input_shape
def custom_activation(scope, operator, container):
    # type:(ScopeBase, OperatorBase, ModelContainer) -> None
    """Emit a single ONNX ``ScaledTanh`` node for a ``ScaledTanh`` Keras layer.

    In this file it is passed to ``onnxmltools.convert_keras`` through
    ``custom_conversion_functions``.

    Args:
        scope: Unused here (presumably required by the converter callback signature).
        operator: Provides ``input_full_names``/``output_full_names`` for wiring, and
            ``original_operator`` (the Keras layer) whose ``alpha``/``beta`` are copied.
        container: Receives the new node via ``add_node``.
    """
    keras_layer = operator.original_operator
    container.add_node('ScaledTanh',
                       operator.input_full_names,
                       operator.output_full_names,
                       op_version=1,
                       alpha=keras_layer.alpha,
                       beta=keras_layer.beta)
class TestInferenceSessionKeras(unittest.TestCase):
    """Convert a Keras model containing a custom layer to ONNX and compare onnxruntime against Keras."""

    def testRunModelConv(self):
        """Conv2D -> ScaledTanh -> MaxPooling2D: Keras and onnxruntime outputs must agree within tolerance."""
        # Keras model; inputs are laid out (N, H, W, C) to match data_format='channels_last'.
        N, C, H, W = 2, 3, 5, 5
        # A seeded, local generator makes failures reproducible without touching NumPy's global RNG state.
        rng = np.random.RandomState(0)
        x = rng.rand(N, H, W, C).astype(np.float32)
        model = Sequential()
        model.add(
            Conv2D(2,
                   kernel_size=(1, 2),
                   strides=(1, 1),
                   padding='valid',
                   input_shape=(H, W, C),
                   data_format='channels_last'))
        model.add(ScaledTanh(0.9, 2.0))
        model.add(MaxPooling2D((2, 2), strides=(2, 2), data_format='channels_last'))
        model.compile(optimizer='sgd', loss='mse')
        # Keras output serves as the reference result.
        expected = model.predict(x)
        self.assertIsNotNone(expected)

        # Conversion: route the custom layer through the converter defined in this file.
        converted_model = onnxmltools.convert_keras(model, custom_conversion_functions={ScaledTanh: custom_activation})
        self.assertIsNotNone(converted_model)

        # Runtime: feed the same input to the serialized ONNX model.
        sess = onnxrt.InferenceSession(converted_model.SerializeToString())
        feeds = {sess.get_inputs()[0].name: x}
        outputs = sess.run(None, feeds)
        self.assertEqual(len(outputs), 1)
        np.testing.assert_allclose(expected, outputs[0], rtol=1e-05, atol=1e-08)
if __name__ == '__main__':
    # unittest.main defaults to module='__main__', i.e. this script's tests.
    # buffer=True hides stdout/stderr from passing tests and shows it only for failures/errors.
    unittest.main(buffer=True)