mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
update from keras2onnx to tf2onnx (#15162)
This commit is contained in:
parent
1b730c3d11
commit
ebc4edfe7a
6 changed files with 24 additions and 24 deletions
8
setup.py
8
setup.py
|
|
@ -114,7 +114,6 @@ _deps = [
|
|||
"jax>=0.2.8",
|
||||
"jaxlib>=0.1.65",
|
||||
"jieba",
|
||||
"keras2onnx",
|
||||
"nltk",
|
||||
"numpy>=1.17",
|
||||
"onnxconverter-common",
|
||||
|
|
@ -147,6 +146,7 @@ _deps = [
|
|||
"starlette",
|
||||
"tensorflow-cpu>=2.3",
|
||||
"tensorflow>=2.3",
|
||||
"tf2onnx",
|
||||
"timeout-decorator",
|
||||
"timm",
|
||||
"tokenizers>=0.10.1",
|
||||
|
|
@ -229,8 +229,8 @@ extras = {}
|
|||
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
|
||||
extras["sklearn"] = deps_list("scikit-learn")
|
||||
|
||||
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx")
|
||||
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx")
|
||||
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
|
||||
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")
|
||||
|
||||
extras["torch"] = deps_list("torch")
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ else:
|
|||
|
||||
extras["tokenizers"] = deps_list("tokenizers")
|
||||
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
||||
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"]
|
||||
extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
|
||||
extras["modelcreation"] = deps_list("cookiecutter")
|
||||
|
||||
extras["sagemaker"] = deps_list("sagemaker")
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
|
|||
|
||||
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
||||
"""
|
||||
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR
|
||||
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
|
||||
|
||||
Args:
|
||||
nlp: The pipeline to be exported
|
||||
|
|
@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
|||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
from keras2onnx import __version__ as k2ov
|
||||
from keras2onnx import convert_keras, save_model
|
||||
from tf2onnx import __version__ as t2ov
|
||||
from tf2onnx import convert_keras, save_model
|
||||
|
||||
print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")
|
||||
print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
|
||||
|
||||
# Build
|
||||
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ deps = {
|
|||
"jax": "jax>=0.2.8",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
"jieba": "jieba",
|
||||
"keras2onnx": "keras2onnx",
|
||||
"nltk": "nltk",
|
||||
"numpy": "numpy>=1.17",
|
||||
"onnxconverter-common": "onnxconverter-common",
|
||||
|
|
@ -57,6 +56,7 @@ deps = {
|
|||
"starlette": "starlette",
|
||||
"tensorflow-cpu": "tensorflow-cpu>=2.3",
|
||||
"tensorflow": "tensorflow>=2.3",
|
||||
"tf2onnx": "tf2onnx",
|
||||
"timeout-decorator": "timeout-decorator",
|
||||
"timm": "timm",
|
||||
"tokenizers": "tokenizers>=0.10.1",
|
||||
|
|
|
|||
|
|
@ -175,12 +175,12 @@ except importlib_metadata.PackageNotFoundError:
|
|||
_sympy_available = False
|
||||
|
||||
|
||||
_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
|
||||
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
|
||||
try:
|
||||
_keras2onnx_version = importlib_metadata.version("keras2onnx")
|
||||
logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
|
||||
_tf2onnx_version = importlib_metadata.version("tf2onnx")
|
||||
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_keras2onnx_available = False
|
||||
_tf2onnx_available = False
|
||||
|
||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||
try:
|
||||
|
|
@ -429,8 +429,8 @@ def is_coloredlogs_available():
|
|||
return _coloredlogs_available
|
||||
|
||||
|
||||
def is_keras2onnx_available():
|
||||
return _keras2onnx_available
|
||||
def is_tf2onnx_available():
|
||||
return _tf2onnx_available
|
||||
|
||||
|
||||
def is_onnx_available():
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ from .file_utils import (
|
|||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_keras2onnx_available,
|
||||
is_librosa_available,
|
||||
is_onnx_available,
|
||||
is_pandas_available,
|
||||
|
|
@ -49,6 +48,7 @@ from .file_utils import (
|
|||
is_soundfile_availble,
|
||||
is_spacy_available,
|
||||
is_tensorflow_probability_available,
|
||||
is_tf2onnx_available,
|
||||
is_tf_available,
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
|
|
@ -246,9 +246,9 @@ def require_rjieba(test_case):
|
|||
return test_case
|
||||
|
||||
|
||||
def require_keras2onnx(test_case):
|
||||
if not is_keras2onnx_available():
|
||||
return unittest.skip("test requires keras2onnx")(test_case)
|
||||
def require_tf2onnx(test_case):
|
||||
if not is_tf2onnx_available():
|
||||
return unittest.skip("test requires tf2onnx")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ from transformers.testing_utils import (
|
|||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_keras2onnx,
|
||||
require_tf,
|
||||
require_tf2onnx,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
|
@ -254,14 +254,14 @@ class TFModelTesterMixin:
|
|||
|
||||
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
||||
|
||||
@require_keras2onnx
|
||||
@require_tf2onnx
|
||||
@slow
|
||||
def test_onnx_runtime_optimize(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
import keras2onnx
|
||||
import onnxruntime
|
||||
import tf2onnx
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
|
@ -269,9 +269,9 @@ class TFModelTesterMixin:
|
|||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
|
||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||
|
||||
onnxruntime.InferenceSession(onnx_model.SerializeToString())
|
||||
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
|
||||
|
||||
def test_keras_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
|
|
|||
Loading…
Reference in a new issue