mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Update BeamSearch operator spec to support t5 (#10777)
* change BeamSearch op to support encoder decoder model * check model_type and decoder attribute * fix * update comments * warn shape inference issue with onnx v1.11 or T5 * skip parity test when temperature != 1.0 * fix build
This commit is contained in:
parent
6be5185088
commit
0e335aba37
11 changed files with 150 additions and 57 deletions
|
|
@ -350,12 +350,16 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>body</tt> : graph (required)</dt>
|
||||
<dd>The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output</dd>
|
||||
<dt><tt>decoder</tt> : graph (required)</dt>
|
||||
<dd>Decoder subgraph to execute in a loop.</dd>
|
||||
<dt><tt>early_stopping</tt> : int</dt>
|
||||
<dd>early stop or not</dd>
|
||||
<dt><tt>encoder_decoder_init</tt> : graph</dt>
|
||||
<dd>subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
|
||||
<dt><tt>eos_token_id</tt> : int (required)</dt>
|
||||
<dd>The id of the end-of-sequence token</dd>
|
||||
<dt><tt>model_type</tt> : int</dt>
|
||||
<dd>model type: 0 for GPT-2; 1 for encoder decoder like T5</dd>
|
||||
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
|
||||
<dd>no repeat ngrams size</dd>
|
||||
<dt><tt>pad_token_id</tt> : int (required)</dt>
|
||||
|
|
|
|||
|
|
@ -256,35 +256,41 @@ class BeamSearchImpl {
|
|||
};
|
||||
|
||||
void BeamSearch::Init(const OpKernelInfo& info) {
|
||||
// Make sure the body attribute was present even though we don't need it here.
|
||||
// Make sure the decoder attribute was present even though we don't need it here.
|
||||
ONNX_NAMESPACE::GraphProto proto;
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("body", &proto).IsOK());
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("decoder", &proto).IsOK());
|
||||
ORT_IGNORE_RETURN_VALUE(proto);
|
||||
|
||||
parameters_.ParseFromAttributes(info);
|
||||
|
||||
cuda_stream_ = nullptr;
|
||||
}
|
||||
|
||||
Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
||||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) {
|
||||
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
const auto& node = Node();
|
||||
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
|
||||
gpt_subgraph_->num_heads,
|
||||
gpt_subgraph_->head_size,
|
||||
gpt_subgraph_->num_layers);
|
||||
// TODO: handle another subgraph with attribute name "encoder_decoder_init"
|
||||
if (attribute_name == "decoder") {
|
||||
const auto& node = Node();
|
||||
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
|
||||
gpt_subgraph_->num_heads,
|
||||
gpt_subgraph_->head_size,
|
||||
gpt_subgraph_->num_layers);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
||||
if (parameters_.model_type != 0) {
|
||||
// TODO: support encoder decoder model like T5
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented");
|
||||
}
|
||||
|
||||
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto* session_state = ctx_internal->SubgraphSessionState("body");
|
||||
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");
|
||||
auto* session_state = ctx_internal->SubgraphSessionState("decoder");
|
||||
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
|
||||
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
|
||||
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
|
|||
|
||||
class BeamSearch : public IControlFlowKernel {
|
||||
public:
|
||||
BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info), cuda_stream_(nullptr), dumper_(nullptr) {
|
||||
BeamSearch(const OpKernelInfo& info)
|
||||
: IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
|
||||
Init(info);
|
||||
}
|
||||
|
||||
|
|
@ -87,7 +88,7 @@ class BeamSearch : public IControlFlowKernel {
|
|||
void* cuda_stream_;
|
||||
|
||||
IConsoleDumper* dumper_;
|
||||
|
||||
|
||||
BeamSearchParameters parameters_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ Status BeamSearchParameters::Validate() const {
|
|||
}
|
||||
|
||||
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
|
||||
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", 0));
|
||||
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
|
||||
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
|
||||
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class IBeamScorer {
|
|||
|
||||
struct IBeamSearchParameters {
|
||||
// Parameters from node attributes
|
||||
int model_type;
|
||||
int eos_token_id;
|
||||
int pad_token_id;
|
||||
int no_repeat_ngram_size;
|
||||
|
|
|
|||
|
|
@ -457,9 +457,8 @@ Status UpdateFeeds(
|
|||
for (size_t i = 1; i < last_outputs.size(); ++i) {
|
||||
next_inputs[i + 2] = last_outputs[i];
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
return PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream);
|
||||
ORT_RETURN_IF_ERROR(PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream));
|
||||
}
|
||||
|
||||
// Make sure data is ready before next subgraph execution.
|
||||
|
|
|
|||
|
|
@ -919,10 +919,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
|
|||
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
|
||||
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"body",
|
||||
"The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output",
|
||||
AttributeProto::GRAPH)
|
||||
.Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
|
||||
.Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
|
||||
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
|
||||
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
|
||||
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
|
||||
|
|
@ -951,8 +950,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
|
|||
BeamSearchShapeInference(ctx);
|
||||
}));
|
||||
|
||||
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1,
|
||||
OpSchema()
|
||||
.Input(0, "X", "input", "T")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,18 @@
|
|||
#-------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#-------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
This converts GPT2 or T5 model to onnx with beam search operator.
|
||||
|
||||
Example 1: convert gpt2 model with beam search:
|
||||
python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
|
||||
|
||||
Example 2: convert T5 model with beam search:
|
||||
python ./models/t5/convert_to_onnx.py -m t5-small -s
|
||||
python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
|
@ -9,21 +22,17 @@ import argparse
|
|||
from pathlib import Path
|
||||
from onnx import helper
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
import torch
|
||||
from transformers import GPT2Config
|
||||
from packaging import version
|
||||
from transformers import GPT2Config, T5Config
|
||||
from gpt2_helper import PRETRAINED_GPT2_MODELS
|
||||
from convert_to_onnx import main as convert_gpt2_to_onnx
|
||||
from benchmark_helper import Precision
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
"""
|
||||
This converts GPT2 model to onnx with beam search operator.
|
||||
|
||||
Examples:
|
||||
python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
|
||||
"""
|
||||
|
||||
config: GPT2Config = None
|
||||
config: Union[GPT2Config, T5Config] = None
|
||||
|
||||
logger = logging.getLogger('')
|
||||
|
||||
|
|
@ -37,16 +46,29 @@ def parse_arguments(argv=None):
|
|||
type=str,
|
||||
help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS))
|
||||
|
||||
parser.add_argument('--model_type',
|
||||
required=False,
|
||||
type=str,
|
||||
default="gpt2",
|
||||
choices=["gpt2", "t5"],
|
||||
help='Model type in the list: ' + ', '.join(["gpt2", "t5"]))
|
||||
|
||||
parser.add_argument('--cache_dir',
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join('.', 'cache_models'),
|
||||
help='Directory to cache pre-trained models')
|
||||
|
||||
parser.add_argument('--gpt2_onnx',
|
||||
parser.add_argument('--decoder_onnx',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Output directory for GPT-2 onnx model, or model path ends with .onnx')
|
||||
help='Output directory for decoder onnx model, or model path ends with .onnx')
|
||||
|
||||
parser.add_argument('--encoder_decoder_init_onnx',
|
||||
required=False,
|
||||
type=str,
|
||||
default="",
|
||||
help='path of ONNX model for encoder and decoder initialization. Required for t5 model type.')
|
||||
|
||||
parser.add_argument('--output',
|
||||
required=False,
|
||||
|
|
@ -153,12 +175,12 @@ def parse_arguments(argv=None):
|
|||
def gpt2_to_onnx(args):
|
||||
model_name = args.model_name_or_path
|
||||
|
||||
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...")
|
||||
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...")
|
||||
arguments = [
|
||||
'--model_name_or_path',
|
||||
model_name,
|
||||
'--output',
|
||||
args.gpt2_onnx,
|
||||
args.decoder_onnx,
|
||||
'--optimize_onnx',
|
||||
'--precision',
|
||||
'fp32' if args.precision == Precision.FLOAT32 else 'fp16',
|
||||
|
|
@ -182,13 +204,16 @@ def gpt2_to_onnx(args):
|
|||
convert_gpt2_to_onnx(arguments)
|
||||
|
||||
|
||||
def shape_inference(gpt2_onnx_path):
|
||||
def shape_inference(decoder_onnx_path):
|
||||
if version.parse(onnx.__version__) >= version.parse('1.11.0'):
|
||||
logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.")
|
||||
|
||||
# Run symbolic shape inference to work around ORT shape inference issue for subgraph.
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False)
|
||||
out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False)
|
||||
if out:
|
||||
# TODO: Use external format if input has extra data.
|
||||
onnx.save(out, gpt2_onnx_path)
|
||||
onnx.save(out, decoder_onnx_path)
|
||||
else:
|
||||
print("Failed to run symbolic shape inference on the model.")
|
||||
|
||||
|
|
@ -224,7 +249,7 @@ def verify_gpt2_subgraph(graph, precision):
|
|||
|
||||
expected_type = onnx_proto.TensorProto.INT32
|
||||
if i >= 3:
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
|
||||
|
||||
if graph.input[i].type.tensor_type.elem_type != expected_type:
|
||||
raise ValueError(
|
||||
|
|
@ -240,7 +265,7 @@ def verify_gpt2_subgraph(graph, precision):
|
|||
if graph.output[i].name != expected_output:
|
||||
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT32
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
|
||||
if graph.output[i].type.tensor_type.elem_type != expected_type:
|
||||
raise ValueError(
|
||||
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
|
||||
|
|
@ -251,17 +276,35 @@ def verify_gpt2_subgraph(graph, precision):
|
|||
return
|
||||
|
||||
|
||||
def verify_t5_decoder_subgraph(graph, precision):
|
||||
# TODO: implement it
|
||||
pass
|
||||
|
||||
|
||||
def verify_t5_encoder_decoder_init_subgraph(graph, precision):
|
||||
# TODO: implement it
|
||||
pass
|
||||
|
||||
|
||||
def convert_model(args):
|
||||
if os.path.exists(args.gpt2_onnx):
|
||||
print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}")
|
||||
if os.path.exists(args.decoder_onnx):
|
||||
print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
|
||||
else:
|
||||
assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2"
|
||||
gpt2_to_onnx(args)
|
||||
|
||||
print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.")
|
||||
shape_inference(args.gpt2_onnx)
|
||||
|
||||
# TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
|
||||
enable_shape_inference = args.model_type == "gpt2"
|
||||
|
||||
if enable_shape_inference:
|
||||
print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
|
||||
shape_inference(args.decoder_onnx)
|
||||
|
||||
global config
|
||||
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
if args.model_type == "gpt2":
|
||||
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
else:
|
||||
config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
print(config)
|
||||
|
||||
eos_token_id = config.eos_token_id
|
||||
|
|
@ -272,10 +315,14 @@ def convert_model(args):
|
|||
if args.vocab_size != -1:
|
||||
vocab_size = args.vocab_size
|
||||
|
||||
model = onnx.load(args.gpt2_onnx)
|
||||
verify_gpt2_subgraph(model.graph, args.precision)
|
||||
model = onnx.load(args.decoder_onnx)
|
||||
model.graph.name = f"{args.model_type} decoder subgraph"
|
||||
|
||||
if args.model_type == "gpt2":
|
||||
verify_gpt2_subgraph(model.graph, args.precision)
|
||||
else:
|
||||
verify_t5_decoder_subgraph(model.graph, args.precision)
|
||||
|
||||
model.graph.name = "gpt2 subgraph"
|
||||
inputs = [
|
||||
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
|
||||
"repetition_penalty", "vocab_mask"
|
||||
|
|
@ -291,16 +338,28 @@ def convert_model(args):
|
|||
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
|
||||
outputs.append("scores")
|
||||
|
||||
node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2')
|
||||
node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name=f'BeamSearch_{args.model_type}')
|
||||
node.domain = "com.microsoft"
|
||||
node.attribute.extend([
|
||||
helper.make_attribute("eos_token_id", eos_token_id),
|
||||
helper.make_attribute("pad_token_id", pad_token_id),
|
||||
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
|
||||
helper.make_attribute("body", model.graph),
|
||||
helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
|
||||
helper.make_attribute("decoder", model.graph),
|
||||
])
|
||||
|
||||
if args.model_type == "t5":
|
||||
if enable_shape_inference:
|
||||
print(f"Run symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
|
||||
shape_inference(args.encoder_decoder_init_onnx)
|
||||
init_model = onnx.load(args.encoder_decoder_init_onnx)
|
||||
init_model.graph.name = f"{args.model_type} encoder decoder init subgraph"
|
||||
verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision)
|
||||
node.attribute.extend([
|
||||
helper.make_attribute("encoder_decoder_init", init_model.graph),
|
||||
])
|
||||
|
||||
from onnx import TensorProto
|
||||
|
||||
# graph inputs
|
||||
|
|
@ -344,7 +403,7 @@ def convert_model(args):
|
|||
if args.output_token_scores:
|
||||
graph_outputs.append(scores)
|
||||
|
||||
new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers)
|
||||
new_graph = helper.make_graph([node], f'{args.model_type}-beam-search', graph_inputs, graph_outputs, initializers)
|
||||
|
||||
# Create the model
|
||||
new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import)
|
||||
|
|
@ -392,10 +451,20 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id,
|
|||
|
||||
|
||||
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
||||
if args.model_type != "gpt2":
|
||||
print(
|
||||
f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime"
|
||||
)
|
||||
return True
|
||||
|
||||
if args.temperature != 1.0:
|
||||
# TODO: implement temperature in BeamSearch operator.
|
||||
print("Skipping parity test as temperature is not implemented in BeamSearch operator")
|
||||
return True
|
||||
|
||||
if args.prefix_vocab_mask:
|
||||
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
|
||||
return
|
||||
return True
|
||||
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
|
||||
|
|
@ -547,6 +616,8 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
|||
|
||||
def main(argv=None, sentences=None):
|
||||
args = parse_arguments(argv)
|
||||
if args.model_type == "t5":
|
||||
assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool"
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print(f"skip conversion since path existed: {args.output}")
|
||||
|
|
|
|||
8
onnxruntime/python/tools/transformers/models/__init__.py
Normal file
8
onnxruntime/python/tools/transformers/models/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
#-------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#-------------------------------------------------------------------------
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
|
@ -1,3 +1,7 @@
|
|||
#-------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#-------------------------------------------------------------------------
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -18,12 +18,13 @@ else:
|
|||
|
||||
|
||||
class TestBeamSearch(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
#TODO: use a smaller model and enable tests in CI pipeline
|
||||
self.model_name = "gpt2"
|
||||
self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx')
|
||||
self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx')
|
||||
self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
|
||||
self.cpu_params = f'-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
|
||||
|
||||
def run_beam_search(self, arguments: str, sentences=None):
|
||||
return run(arguments.split(), sentences=sentences)
|
||||
|
|
|
|||
Loading…
Reference in a new issue