Adds OnnxTransformer to plug onnxruntime in scikit-learn's pipeline (#389)

Useful for transfer learning
This commit is contained in:
Xavier Dupré 2019-01-29 18:51:24 +01:00 committed by GitHub
parent 7c21c15732
commit 439dbbada9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 160 additions and 1 deletions

View file

@ -115,6 +115,9 @@ file(GLOB onnxruntime_python_datasets_data
"${ONNXRUNTIME_ROOT}/python/datasets/*.pb"
"${ONNXRUNTIME_ROOT}/python/datasets/*.onnx"
)
file(GLOB onnxruntime_python_sklapi_srcs
"${ONNXRUNTIME_ROOT}/python/sklapi/*.py"
)
# adjust based on what target/s onnxruntime_unittests.cmake created
if (SingleUnitTestProject)
@ -129,6 +132,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/capi
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/datasets
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/sklapi
COMMAND ${CMAKE_COMMAND} -E copy
${ONNXRUNTIME_ROOT}/__init__.py
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/
@ -156,6 +160,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_datasets_data}
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/datasets/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_sklapi_srcs}
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/sklapi/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_tools_srcs}
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools/

View file

@ -52,7 +52,7 @@ source_suffix = ['.rst'] # , '.md']
# enables markdown output
try:
import docfx_markdown
extension.extend([
extensions.extend([
"docfx_yaml.extension",
"docfx_markdown",
])

View file

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from .onnx_transformer import OnnxTransformer

View file

@ -0,0 +1,104 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
"""
Wraps runtime into a *scikit-learn* transformer.
"""
import numpy
import pandas
from sklearn.base import BaseEstimator, TransformerMixin
from .. import InferenceSession
class OnnxTransformer(BaseEstimator, TransformerMixin):
    """
    Calls *onnxruntime* inference following *scikit-learn* API
    so that it can be included in a *scikit-learn* pipeline.

    Attributes set by :meth:`fit`:

    * ``onnxrt_``: the ``InferenceSession`` created from *onnx_bytes*
    * ``inputs_``: names of the model inputs as reported by the session
    """

    def __init__(self, onnx_bytes, output_name=None):
        """
        :param onnx_bytes: serialized *ONNX* model (bytes)
        :param output_name: requested output name or None to request all
            outputs; with None, method *transform* gathers several outputs
            into a dataframe
        :raises TypeError: if *onnx_bytes* is not an instance of ``bytes``
        """
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        if not isinstance(onnx_bytes, bytes):
            raise TypeError("onnx_bytes must be bytes to be pickled.")
        self.onnx_bytes = onnx_bytes
        self.output_name = output_name

    def fit(self, X=None, y=None, **fit_params):
        """
        Loads the *ONNX* model.

        Parameters
        ----------
        X : unused
        y : unused
        fit_params : unused

        Returns
        -------
        self
        """
        self.onnxrt_ = InferenceSession(self.onnx_bytes)
        self.inputs_ = [node.name for node in self.onnxrt_.get_inputs()]
        return self

    def transform(self, X, y=None, **inputs):
        """
        Runs the predictions. If *X* is a dataframe,
        the function assumes every column is a separate input
        named after the column, otherwise, if *X* is a numpy array,
        it is used as the first input and *inputs*
        can be used to specify extra inputs.

        Parameters
        ----------
        X : DataFrame or numpy array, data to process
            (or first input if several expected); any other type
            is ignored and only *inputs* is sent to the runtime
        y : unused
        inputs : additional inputs (input number >= 1), name -> value

        Returns
        -------
        The requested output if *output_name* was specified,
        the only output if the model produces a single one,
        otherwise a DataFrame with one column per model output.

        Raises
        ------
        AttributeError
            If :meth:`fit` was not called first.
        """
        if not hasattr(self, "onnxrt_"):
            raise AttributeError("The transform must be fit first.")

        rt_inputs = {}
        if isinstance(X, pandas.DataFrame):
            for c in X.columns:
                # Hand the runtime the column's underlying array
                # rather than the pandas Series wrapper.
                rt_inputs[c] = X[c].values
        elif isinstance(X, numpy.ndarray):
            rt_inputs[self.inputs_[0]] = X
        # Explicit keyword inputs come last so they override
        # anything derived from X with the same name.
        rt_inputs.update(inputs)

        if self.output_name:
            return self.onnxrt_.run([self.output_name], rt_inputs)[0]

        outputs = self.onnxrt_.run(None, rt_inputs)
        if len(outputs) == 1:
            return outputs[0]
        # output_name is None on this path: the column names must come
        # from the session, not from self.output_name.
        names = [node.name for node in self.onnxrt_.get_outputs()]
        return pandas.DataFrame(dict(zip(names, outputs)))

    def fit_transform(self, X, y=None, **inputs):
        """
        Loads the *ONNX* model and runs the predictions.

        Parameters
        ----------
        X : DataFrame or numpy array, data to process
            (or first input if several expected)
        y : unused
        inputs : additional inputs (input number >= 1), name -> value

        Returns
        -------
        Same as method *transform*.
        """
        # The extra inputs must reach transform as well, fit ignores them.
        return self.fit(X, y=y, **inputs).transform(X, y, **inputs)

View file

@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import unittest
import os
import sys
import numpy as np
import onnxruntime as onnxrt
from onnxruntime.capi._pybind_state import onnxruntime_ostream_redirect
from onnxruntime.sklapi import OnnxTransformer
class TestInferenceSessionSklearn(unittest.TestCase):
    """Unit tests for *OnnxTransformer*."""

    def get_name(self, name):
        """
        Locates a test data file.

        The candidates are tried in this order: *name* itself,
        ``testdata/<name>`` and ``<folder of this file>/../testdata/<name>``.

        :param name: file name or path
        :return: the first candidate which exists
        :raises FileNotFoundError: if none of the candidates exists
        """
        if os.path.exists(name):
            return name
        rel = os.path.join("testdata", name)
        if os.path.exists(rel):
            return rel
        res = os.path.join(os.path.dirname(__file__), "..", "testdata", name)
        if os.path.exists(res):
            return res
        raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))

    def test_transform(self):
        """Fits the transformer on ``mul_1.pb`` and checks the output for a small array."""
        x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
        name = self.get_name("mul_1.pb")
        with open(name, "rb") as f:
            content = f.read()

        tr = OnnxTransformer(content)
        tr.fit()
        res = tr.transform(x)

        # The expected values are x multiplied element-wise by itself.
        exp = np.array([[1., 4.], [9., 16.], [25., 36.]], dtype=np.float32)
        # assertEqual instead of a bare assert: it is not removed by
        # ``python -O`` and shows both lists when the comparison fails.
        self.assertEqual(list(res.ravel()), list(exp.ravel()))
# Runs every test of this file when it is executed as a script.
if __name__ == "__main__":
    unittest.main()