Support missing optional attribute in Squeeze operator (#1505)

* Make Squeeze operator support no axes attribute cases

* Fix build break

* Resolve PR comments and exclude tensorrt for the new tests
This commit is contained in:
Hariharan Seshadri 2019-07-26 11:16:35 -07:00 committed by GitHub
parent 717e764e8e
commit 6f538dc861
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 8 deletions

View file

@ -13,13 +13,15 @@ class SqueezeBase {
protected:
explicit SqueezeBase(const OpKernelInfo& info) {
std::vector<int64_t> axes;
Status status = info.GetAttrs<int64_t>("axes", axes);
ORT_ENFORCE(status.IsOK(), "Attribute axes is not set.");
// Parse attribute 'axes'
Status status = info.GetAttrs<int64_t>("axes", axes);
// Handle out of order and repeating dims.
std::sort(axes.begin(), axes.end());
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
axes_ = axes;
// Handle out of order and repeating dims when 'axes' exists.
if (status.IsOK()) {
std::sort(axes.begin(), axes.end());
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
axes_ = axes;
}
}
static std::vector<int64_t> ComputeOutputShape(
@ -28,7 +30,8 @@ class SqueezeBase {
size_t j = 0;
std::vector<int64_t> output_shape;
for (size_t i = 0; i < input_shape.NumDimensions(); ++i) {
if (j < axes.NumDimensions() && axes[j] == static_cast<int64_t>(i)) {
if ((j < axes.NumDimensions() && axes[j] == static_cast<int64_t>(i)) ||
(axes.NumDimensions() == 0 && input_shape[i] == 1)) {
ORT_ENFORCE(input_shape[i] == 1, "Dimension of input ", i, " must be 1 instead of ", input_shape[i],
". shape=", input_shape);
++j;
@ -59,4 +62,4 @@ class Squeeze final : public OpKernel, public SqueezeBase {
}
};
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -18,6 +18,23 @@ TEST(SqueezeOpTest, Squeeze_1) {
test.Run();
}
TEST(SqueezeOpTest, Squeeze_Empty_Axes_1) {
OpTester test("Squeeze");
test.AddInput<float>("data", {1, 1, 4, 1}, std::vector<float>(4, 1.0f));
test.AddOutput<float>("squeezed", {4}, std::vector<float>(4, 1.0f));
// TensorRT doesn't seem to support missing 'axes'
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
TEST(SqueezeOpTest, Squeeze_Empty_Axes_2) {
OpTester test("Squeeze");
// nothing to "squeeze" out in the input shape
test.AddInput<float>("data", {2, 4}, std::vector<float>(8, 1.0f));
test.AddOutput<float>("squeezed", {2, 4}, std::vector<float>(8, 1.0f));
// TensorRT doesn't seem to support missing 'axes'
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
TEST(SqueezeOpTest, Squeeze_1_int32) {
OpTester test("Squeeze");
test.AddAttribute("axes", std::vector<int64_t>{0});