mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
remove OnnxTransformer (#554)
This commit is contained in:
parent
bf43ac41aa
commit
c6d39b60cd
4 changed files with 0 additions and 159 deletions
|
|
@ -118,9 +118,6 @@ file(GLOB onnxruntime_python_datasets_data
|
|||
"${ONNXRUNTIME_ROOT}/python/datasets/*.pb"
|
||||
"${ONNXRUNTIME_ROOT}/python/datasets/*.onnx"
|
||||
)
|
||||
file(GLOB onnxruntime_python_sklapi_srcs
|
||||
"${ONNXRUNTIME_ROOT}/python/sklapi/*.py"
|
||||
)
|
||||
|
||||
# adjust based on what target/s onnxruntime_unittests.cmake created
|
||||
if (SingleUnitTestProject)
|
||||
|
|
@ -135,7 +132,6 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/capi
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/datasets
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/sklapi
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${ONNXRUNTIME_ROOT}/__init__.py
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/
|
||||
|
|
@ -163,9 +159,6 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_datasets_data}
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/datasets/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_sklapi_srcs}
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/sklapi/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_tools_srcs}
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/tools/
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .onnx_transformer import OnnxTransformer
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
#-------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#--------------------------------------------------------------------------
|
||||
"""
|
||||
Wraps runtime into a *scikit-learn* transformer.
|
||||
"""
|
||||
import numpy
|
||||
import pandas
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
from .. import InferenceSession
|
||||
|
||||
|
||||
class OnnxTransformer(BaseEstimator, TransformerMixin):
|
||||
"""
|
||||
Calls *onnxruntime* inference following *scikit-learn* API
|
||||
so that it can be included in a *scikit-learn* pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, onnx_bytes, output_name=None):
|
||||
"""
|
||||
:param onnx_bytes: bytes
|
||||
:param output_name: requested output name or None to request all and
|
||||
have method *transform* to store all of them in a dataframe
|
||||
"""
|
||||
BaseEstimator.__init__(self)
|
||||
TransformerMixin.__init__(self)
|
||||
self.onnx_bytes = onnx_bytes
|
||||
self.output_name = output_name
|
||||
if not isinstance(onnx_bytes, bytes):
|
||||
raise TypeError("onnx_bytes must be bytes to be pickled.")
|
||||
|
||||
def fit(self, X=None, y=None, **fit_params):
|
||||
"""
|
||||
Loads the *ONNX* model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : unused
|
||||
y : unused
|
||||
|
||||
Returns
|
||||
-------
|
||||
self
|
||||
"""
|
||||
self.onnxrt_ = InferenceSession(self.onnx_bytes)
|
||||
self.inputs_ = [_.name for _ in self.onnxrt_.get_inputs()]
|
||||
return self
|
||||
|
||||
def transform(self, X, y=None, **inputs):
|
||||
"""
|
||||
Runs the predictions. If *X* is a dataframe,
|
||||
the function assumes every columns is a separate input,
|
||||
otherwise, *X* is considered as a first input and *inputs*
|
||||
can be used to specify extra inputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : iterable, data to process (or first input if several expected)
|
||||
y : unused
|
||||
inputs: additional inputs (input number >= 1)
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataFrame
|
||||
"""
|
||||
if not hasattr(self, "onnxrt_"):
|
||||
raise AttributeError("The transform must be fit first.")
|
||||
rt_inputs = {}
|
||||
if isinstance(X, pandas.DataFrame):
|
||||
for c in X.columns:
|
||||
rt_inputs[c] = X[c]
|
||||
elif isinstance(X, numpy.ndarray):
|
||||
rt_inputs[self.inputs_[0]] = X
|
||||
|
||||
for k, v in inputs.items():
|
||||
rt_inputs[k] = v
|
||||
|
||||
names = [self.output_name] if self.output_name else None
|
||||
outputs = self.onnxrt_.run(names, rt_inputs)
|
||||
|
||||
if self.output_name:
|
||||
return outputs[0]
|
||||
else:
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
else:
|
||||
return pandas.DataFrame({k: v for k, v in zip(self.output_name, outputs)})
|
||||
|
||||
def fit_transform(self, X, y=None, **inputs):
|
||||
"""
|
||||
Loads the *ONNX* model and runs the predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : iterable, data to process (or first input if several expected)
|
||||
y : unused
|
||||
inputs: additional inputs (input number >= 1)
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataFrame
|
||||
"""
|
||||
return self.fit(X, y=y, **inputs).transform(X, y)
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# -*- coding: UTF-8 -*-
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import onnxruntime as onnxrt
|
||||
from onnxruntime.capi._pybind_state import onnxruntime_ostream_redirect
|
||||
from onnxruntime.sklapi import OnnxTransformer
|
||||
|
||||
|
||||
class TestInferenceSessionSklearn(unittest.TestCase):
|
||||
|
||||
def get_name(self, name):
|
||||
if os.path.exists(name):
|
||||
return name
|
||||
rel = os.path.join("testdata", name)
|
||||
if os.path.exists(rel):
|
||||
return rel
|
||||
this = os.path.dirname(__file__)
|
||||
data = os.path.join(this, "..", "testdata")
|
||||
res = os.path.join(data, name)
|
||||
if os.path.exists(res):
|
||||
return res
|
||||
raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res))
|
||||
|
||||
def test_transform(self):
|
||||
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
|
||||
name = self.get_name("mul_1.pb")
|
||||
with open(name, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
tr = OnnxTransformer(content)
|
||||
tr.fit()
|
||||
res = tr.transform(x)
|
||||
exp = np.array([[ 1., 4.], [ 9., 16.], [25., 36.]], dtype=np.float32)
|
||||
assert list(res.ravel()) == list(exp.ravel())
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue