mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Allow sequence length to be symbolic (#2559)
This commit is contained in:
parent
73c682b97c
commit
038ee91da5
1 changed files with 2 additions and 6 deletions
|
|
@ -241,14 +241,12 @@ will be calculated.)DOC";
|
|||
auto& input_ids_shape = getInputShape(ctx, 0);
|
||||
auto& input_ids_dims = input_ids_shape.dim();
|
||||
|
||||
// Note that both batch size and sequence length could be symbolic.
|
||||
// So we only check dimension size here.
|
||||
if (input_ids_dims.size() != 2) {
|
||||
fail_shape_inference("Inputs 0 shall be 2 dimensions");
|
||||
}
|
||||
|
||||
if (!input_ids_dims[1].has_dim_value()) {
|
||||
fail_shape_inference("Inputs 0 shall have value in dimension 1");
|
||||
}
|
||||
|
||||
// get hidden_size from the last dimension of embedding
|
||||
auto& word_embedding_shape = getInputShape(ctx, 3);
|
||||
auto& word_embedding_dims = word_embedding_shape.dim();
|
||||
|
|
@ -259,8 +257,6 @@ will be calculated.)DOC";
|
|||
}
|
||||
int64_t hidden_size = word_embedding_shape.dim(1).dim_value();
|
||||
|
||||
|
||||
|
||||
// input shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
|
||||
ONNX_NAMESPACE::TensorShapeProto output_shape;
|
||||
for (auto& dim : input_ids_dims) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue