Fixed bug in Flatten's axis (#18645)

Flatten's axis is in the range [-r, r] rather than [-r, r-1].
This commit is contained in:
Wanming Lin 2023-12-01 08:19:22 +08:00 committed by GitHub
parent 6781b6cf3d
commit 73a2eb82eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -36,7 +36,11 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
int64_t rank = input_shape.size();
NodeAttrHelper helper(node);
int64_t axis = helper.Get("axis", 1);
axis = HandleNegativeAxis(axis, rank);
ORT_ENFORCE(axis >= -rank && axis <= rank, "axis ", axis,
" is not in valid range [-", rank, ",", rank, "]");
if (axis < 0) {
axis += rank;
}
// Use WebNN's reshape to implement Flatten.
int64_t num_pre_axis_elements = std::accumulate(