From 439dbbada9546fa41d8abfe09d94b7129f934fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 29 Jan 2019 18:51:24 +0100 Subject: [PATCH] Adds OnnxTransformer to plug onnxruntime in scikit-learn's pipeline (#389) Useful for transfer learning --- cmake/onnxruntime_python.cmake | 7 ++ docs/python/conf.py | 2 +- onnxruntime/python/sklapi/__init__.py | 4 + onnxruntime/python/sklapi/onnx_transformer.py | 104 ++++++++++++++++++ .../python/onnxruntime_test_python_skl.py | 44 ++++++++ 5 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/python/sklapi/__init__.py create mode 100644 onnxruntime/python/sklapi/onnx_transformer.py create mode 100644 onnxruntime/test/python/onnxruntime_test_python_skl.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 4cb1effbe4..d00a4a0cef 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -115,6 +115,9 @@ file(GLOB onnxruntime_python_datasets_data "${ONNXRUNTIME_ROOT}/python/datasets/*.pb" "${ONNXRUNTIME_ROOT}/python/datasets/*.onnx" ) +file(GLOB onnxruntime_python_sklapi_srcs + "${ONNXRUNTIME_ROOT}/python/sklapi/*.py" +) # adjust based on what target/s onnxruntime_unittests.cmake created if (SingleUnitTestProject) @@ -129,6 +132,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/capi COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/datasets COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/sklapi COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -156,6 +160,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_datasets_data} $/onnxruntime/datasets/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_sklapi_srcs} + $/onnxruntime/sklapi/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_tools_srcs} $/onnxruntime/tools/ diff --git 
a/docs/python/conf.py b/docs/python/conf.py index 26bf21d55f..7e1a018234 100644 --- a/docs/python/conf.py +++ b/docs/python/conf.py @@ -52,7 +52,7 @@ source_suffix = ['.rst'] # , '.md'] # enables markdown output try: import docfx_markdown - extension.extend([ + extensions.extend([ "docfx_yaml.extension", "docfx_markdown", ]) diff --git a/onnxruntime/python/sklapi/__init__.py b/onnxruntime/python/sklapi/__init__.py new file mode 100644 index 0000000000..6f72adde6d --- /dev/null +++ b/onnxruntime/python/sklapi/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from .onnx_transformer import OnnxTransformer diff --git a/onnxruntime/python/sklapi/onnx_transformer.py b/onnxruntime/python/sklapi/onnx_transformer.py new file mode 100644 index 0000000000..b3823736a2 --- /dev/null +++ b/onnxruntime/python/sklapi/onnx_transformer.py @@ -0,0 +1,104 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +#-------------------------------------------------------------------------- +""" +Wraps runtime into a *scikit-learn* transformer. +""" +import numpy +import pandas +from sklearn.base import BaseEstimator, TransformerMixin +from .. import InferenceSession + + +class OnnxTransformer(BaseEstimator, TransformerMixin): + """ + Calls *onnxruntime* inference following *scikit-learn* API + so that it can be included in a *scikit-learn* pipeline. 
+    """
+
+    def __init__(self, onnx_bytes, output_name=None):
+        """
+        :param onnx_bytes: serialized *ONNX* model, must be *bytes*
+        :param output_name: requested output name or None to request all
+            outputs (see method *transform* for the returned type)
+        """
+        BaseEstimator.__init__(self)
+        TransformerMixin.__init__(self)
+        self.onnx_bytes = onnx_bytes
+        self.output_name = output_name
+        if not isinstance(onnx_bytes, bytes):
+            raise TypeError("onnx_bytes must be bytes to be pickled.")
+
+    def fit(self, X=None, y=None, **fit_params):
+        """
+        Loads the *ONNX* model.
+
+        Parameters
+        ----------
+        X : unused
+        y : unused
+
+        Returns
+        -------
+        self
+        """
+        self.onnxrt_ = InferenceSession(self.onnx_bytes)
+        self.inputs_ = [_.name for _ in self.onnxrt_.get_inputs()]
+        return self
+
+    def transform(self, X, y=None, **inputs):
+        """
+        Runs the predictions. If *X* is a dataframe,
+        the function assumes every column is a separate input,
+        otherwise, *X* is considered as a first input and *inputs*
+        can be used to specify extra inputs.
+
+        Parameters
+        ----------
+        X : iterable, data to process (or first input if several expected)
+        y : unused
+        inputs: additional inputs (input number >= 1)
+
+        Returns
+        -------
+        the single output, or a DataFrame with one column per output
+        """
+        if not hasattr(self, "onnxrt_"):
+            raise AttributeError("The transform must be fit first.")
+        rt_inputs = {}
+        if isinstance(X, pandas.DataFrame):
+            for c in X.columns:
+                rt_inputs[c] = X[c]
+        elif isinstance(X, numpy.ndarray):
+            rt_inputs[self.inputs_[0]] = X
+
+        for k, v in inputs.items():
+            rt_inputs[k] = v
+
+        names = [self.output_name] if self.output_name else None
+        outputs = self.onnxrt_.run(names, rt_inputs)
+
+        if self.output_name or len(outputs) == 1:
+            return outputs[0]
+
+        # No output was requested: self.output_name is None here, so the
+        # column names must come from the model, not from zip(None, ...).
+        out_names = [_.name for _ in self.onnxrt_.get_outputs()]
+        return pandas.DataFrame(dict(zip(out_names, outputs)))
+
+    def fit_transform(self, X, y=None, **inputs):
+        """
+        Loads the *ONNX* model and runs the predictions.
+
+        Parameters
+        ----------
+        X : iterable, data to process (or first input if several expected)
+        y : unused
+        inputs: additional inputs (input number >= 1)
+
+        Returns
+        -------
+        same value as method *transform*
+        """
+        return self.fit(X, y=y, **inputs).transform(X, y, **inputs)
diff --git a/onnxruntime/test/python/onnxruntime_test_python_skl.py b/onnxruntime/test/python/onnxruntime_test_python_skl.py
new file mode 100644
index 0000000000..9089fae80d
--- /dev/null
+++ b/onnxruntime/test/python/onnxruntime_test_python_skl.py
@@ -0,0 +1,44 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# -*- coding: UTF-8 -*-
+import unittest
+import os
+import sys
+import numpy as np
+import onnxruntime as onnxrt
+from onnxruntime.capi._pybind_state import onnxruntime_ostream_redirect
+from onnxruntime.sklapi import OnnxTransformer
+
+
+class TestInferenceSessionSklearn(unittest.TestCase):
+
+    def get_name(self, name):
+        if os.path.exists(name):
+            return name
+        rel = os.path.join("testdata", name)
+        if os.path.exists(rel):
+            return rel
+        this = os.path.dirname(__file__)
+        data = os.path.join(this, "..", "testdata")
+        res = os.path.join(data, name)
+        if os.path.exists(res):
+            return res
+        raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))
+
+    def test_transform(self):
+        """Checks the output of *transform* on the mul_1.pb model."""
+        x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
+        name = self.get_name("mul_1.pb")
+        with open(name, "rb") as f:
+            content = f.read()
+
+        tr = OnnxTransformer(content)
+        tr.fit()
+        res = tr.transform(x)
+        exp = np.array([[1., 4.], [9., 16.], [25., 36.]], dtype=np.float32)
+        self.assertEqual(res.ravel().tolist(), exp.ravel().tolist())
+
+
+if __name__ == '__main__':
+    unittest.main()