From 038ee91da52e9167dd35a487cf289893a3b2ee0f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 6 Dec 2019 10:13:56 -0800 Subject: [PATCH] Allow sequence length to be symbolic (#2559) --- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 6b1cd16c84..2db0c69176 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -241,14 +241,12 @@ will be calculated.)DOC"; auto& input_ids_shape = getInputShape(ctx, 0); auto& input_ids_dims = input_ids_shape.dim(); + // Note that both batch size and sequence length could be symbolic. + // So we only check dimension size here. if (input_ids_dims.size() != 2) { fail_shape_inference("Inputs 0 shall be 2 dimensions"); } - if (!input_ids_dims[1].has_dim_value()) { - fail_shape_inference("Inputs 0 shall have value in dimension 1"); - } - // get hidden_size from the last dimension of embedding auto& word_embedding_shape = getInputShape(ctx, 3); auto& word_embedding_dims = word_embedding_shape.dim(); @@ -259,8 +257,6 @@ will be calculated.)DOC"; } int64_t hidden_size = word_embedding_shape.dim(1).dim_value(); - - // input shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size) ONNX_NAMESPACE::TensorShapeProto output_shape; for (auto& dim : input_ids_dims) {