From daf87fd0dd3eaebb3ac9917797342770b9d6fbb0 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 21 Apr 2022 21:01:40 -0700 Subject: [PATCH] specify the path for gpt2_helper in onnx_exporter.py (#11301) --- onnxruntime/python/tools/transformers/onnx_exporter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 62cf254462..41aaea056d 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -12,11 +12,15 @@ from pathlib import Path from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig from affinity_helper import AffinitySetting from benchmark_helper import create_onnxruntime_session, Precision, OptimizerInfo -from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState from quantize_helper import QuantizeHelper from huggingface_models import MODEL_CLASSES from torch_onnx_export_helper import torch_onnx_export +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), 'models', 'gpt2')) +from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' logger = logging.getLogger(__name__)