mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Update ReformatSourcePython.bat to use YAPF to format Python code, and add the onnxruntime\test directory to the set of formatted paths. Add onnxruntime\.style.yapf for configuration; the style is based on Google's, except that the maximum column width is 120. Format Python scripts using ReformatSourcePython.bat.
81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
# -*- coding: UTF-8 -*-
|
|
# Taken from https://github.com/onnx/onnxmltools/blob/master/tests/end2end/test_custom_op.py.
|
|
import unittest
|
|
import os
|
|
import sys
|
|
import numpy as np
|
|
import onnxmltools
|
|
import onnxruntime as onnxrt
|
|
from keras import backend as K
|
|
from keras import Sequential
|
|
from keras.layers import Layer, Conv2D, MaxPooling2D
|
|
|
|
|
|
class ScaledTanh(Layer):
    """Custom Keras layer computing ``alpha * tanh(beta * x)`` element-wise.

    The test converts a model containing this layer to ONNX by registering
    ``custom_activation`` as its conversion function.

    Args:
        alpha: Scale applied to the tanh output.
        beta: Scale applied to the input before the tanh.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, alpha=1.0, beta=1.0, **kwargs):
        super(ScaledTanh, self).__init__(**kwargs)
        self.alpha = alpha
        self.beta = beta

    def build(self, input_shape):
        # No weights are created here; only the base-class build is run.
        super(ScaledTanh, self).build(input_shape)

    def call(self, x):
        return self.alpha * K.tanh(self.beta * x)

    def compute_output_shape(self, input_shape):
        # The operation is element-wise, so the shape passes through unchanged.
        return input_shape

    def get_config(self):
        """Return the layer config extended with ``alpha`` and ``beta``.

        Without this override, the base-class config does not include the two
        hyper-parameters. A layer rebuilt from the serialized config would then
        not reproduce the original ``alpha``/``beta``.
        """
        config = super(ScaledTanh, self).get_config()
        config.update({'alpha': self.alpha, 'beta': self.beta})
        return config
|
|
|
|
|
|
def custom_activation(scope, operator, container):
    # type:(ScopeBase, OperatorBase, ModelContainer) -> None
    """Emit an ONNX ``ScaledTanh`` node for a ``ScaledTanh`` Keras layer.

    Passed to ``onnxmltools.convert_keras`` through
    ``custom_conversion_functions``.

    Args:
        scope: Conversion scope supplied by the converter; not used here.
        operator: Converter operator whose ``original_operator`` is the Keras
            layer; that layer's ``alpha`` and ``beta`` become node attributes.
        container: Receives the new node via ``add_node``.
    """
    node_inputs = operator.input_full_names
    node_outputs = operator.output_full_names
    keras_layer = operator.original_operator
    container.add_node('ScaledTanh',
                       node_inputs,
                       node_outputs,
                       op_version=1,
                       alpha=keras_layer.alpha,
                       beta=keras_layer.beta)
|
|
|
|
|
|
class TestInferenceSessionKeras(unittest.TestCase):
    """End-to-end test: Keras model with a custom layer -> ONNX -> onnxruntime."""

    def testRunModelConv(self):
        """Convert a Conv2D -> ScaledTanh -> MaxPooling2D model and compare outputs.

        The Keras prediction is the reference. The ONNX model produced by
        ``onnxmltools.convert_keras`` (with ``custom_activation`` handling
        ``ScaledTanh``) must reproduce it when run by onnxruntime.
        """
        # keras model
        N, C, H, W = 2, 3, 5, 5
        # A local, fixed-seed generator keeps the input data reproducible across
        # runs. It does not seed the layer weight initialization.
        rng = np.random.RandomState(0)
        # The layers below use data_format='channels_last', so the input is (N, H, W, C).
        x = rng.rand(N, H, W, C).astype(np.float32, copy=False)

        model = Sequential()
        model.add(
            Conv2D(2,
                   kernel_size=(1, 2),
                   strides=(1, 1),
                   padding='valid',
                   input_shape=(H, W, C),
                   data_format='channels_last'))
        model.add(ScaledTanh(0.9, 2.0))
        model.add(MaxPooling2D((2, 2), strides=(2, 2), data_format='channels_last'))

        model.compile(optimizer='sgd', loss='mse')
        expected = model.predict(x)
        self.assertIsNotNone(expected)

        # conversion
        converted_model = onnxmltools.convert_keras(model,
                                                    custom_conversion_functions={ScaledTanh: custom_activation})
        self.assertIsNotNone(converted_model)

        # runtime
        content = converted_model.SerializeToString()
        sess = onnxrt.InferenceSession(content)
        model_inputs = sess.get_inputs()
        # The feed below binds only the first input, so make that assumption explicit.
        self.assertEqual(len(model_inputs), 1)
        feeds = {model_inputs[0].name: x}
        actual_rt = sess.run(None, feeds)
        self.assertEqual(len(actual_rt), 1)
        np.testing.assert_allclose(expected, actual_rt[0], rtol=1e-05, atol=1e-08)
|
|
|
|
|
|
if __name__ == '__main__':
    # buffer=True holds back stdout/stderr from passing tests; output is shown only for failures and errors.
    unittest.main(buffer=True, module=__name__)
|