diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 62cf254462..41aaea056d 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -12,11 +12,15 @@ from pathlib import Path from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig from affinity_helper import AffinitySetting from benchmark_helper import create_onnxruntime_session, Precision, OptimizerInfo -from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState from quantize_helper import QuantizeHelper from huggingface_models import MODEL_CLASSES from torch_onnx_export_helper import torch_onnx_export +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), 'models', 'gpt2')) +from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' logger = logging.getLogger(__name__)