Allow sequence length to be symbolic (#2559)

This commit is contained in:
Tianlei Wu 2019-12-06 10:13:56 -08:00 committed by GitHub
parent 73c682b97c
commit 038ee91da5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -241,14 +241,12 @@ will be calculated.)DOC";
auto& input_ids_shape = getInputShape(ctx, 0);
auto& input_ids_dims = input_ids_shape.dim();
// Note that both batch size and sequence length could be symbolic.
// So we only check dimension size here.
if (input_ids_dims.size() != 2) {
fail_shape_inference("Inputs 0 shall be 2 dimensions");
}
if (!input_ids_dims[1].has_dim_value()) {
fail_shape_inference("Inputs 0 shall have value in dimension 1");
}
// get hidden_size from the last dimension of embedding
auto& word_embedding_shape = getInputShape(ctx, 3);
auto& word_embedding_dims = word_embedding_shape.dim();
@ -259,8 +257,6 @@ will be calculated.)DOC";
}
int64_t hidden_size = word_embedding_shape.dim(1).dim_value();
// input shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
ONNX_NAMESPACE::TensorShapeProto output_shape;
for (auto& dim : input_ids_dims) {