mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Support missing optional attribute in Squeeze operator (#1505)
* Make Squeeze operator support no axes attribute cases * Fix build break * Resolve PR comments and exclude tensorrt for the new tests
This commit is contained in:
parent
717e764e8e
commit
6f538dc861
2 changed files with 28 additions and 8 deletions
|
|
@ -13,13 +13,15 @@ class SqueezeBase {
|
|||
protected:
|
||||
explicit SqueezeBase(const OpKernelInfo& info) {
|
||||
std::vector<int64_t> axes;
|
||||
Status status = info.GetAttrs<int64_t>("axes", axes);
|
||||
ORT_ENFORCE(status.IsOK(), "Attribute axes is not set.");
|
||||
// Parse attribute 'axes'
|
||||
Status status = info.GetAttrs<int64_t>("axes", axes);
|
||||
|
||||
// Handle out of order and repeating dims.
|
||||
std::sort(axes.begin(), axes.end());
|
||||
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
|
||||
axes_ = axes;
|
||||
// Handle out of order and repeating dims when 'axes' exists.
|
||||
if (status.IsOK()) {
|
||||
std::sort(axes.begin(), axes.end());
|
||||
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
|
||||
axes_ = axes;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int64_t> ComputeOutputShape(
|
||||
|
|
@ -28,7 +30,8 @@ class SqueezeBase {
|
|||
size_t j = 0;
|
||||
std::vector<int64_t> output_shape;
|
||||
for (size_t i = 0; i < input_shape.NumDimensions(); ++i) {
|
||||
if (j < axes.NumDimensions() && axes[j] == static_cast<int64_t>(i)) {
|
||||
if ((j < axes.NumDimensions() && axes[j] == static_cast<int64_t>(i)) ||
|
||||
(axes.NumDimensions() == 0 && input_shape[i] == 1)) {
|
||||
ORT_ENFORCE(input_shape[i] == 1, "Dimension of input ", i, " must be 1 instead of ", input_shape[i],
|
||||
". shape=", input_shape);
|
||||
++j;
|
||||
|
|
@ -59,4 +62,4 @@ class Squeeze final : public OpKernel, public SqueezeBase {
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -18,6 +18,23 @@ TEST(SqueezeOpTest, Squeeze_1) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SqueezeOpTest, Squeeze_Empty_Axes_1) {
|
||||
OpTester test("Squeeze");
|
||||
test.AddInput<float>("data", {1, 1, 4, 1}, std::vector<float>(4, 1.0f));
|
||||
test.AddOutput<float>("squeezed", {4}, std::vector<float>(4, 1.0f));
|
||||
// TensorRT doesn't seem to support missing 'axes'
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(SqueezeOpTest, Squeeze_Empty_Axes_2) {
|
||||
OpTester test("Squeeze");
|
||||
// nothing to "squeeze" out in the input shape
|
||||
test.AddInput<float>("data", {2, 4}, std::vector<float>(8, 1.0f));
|
||||
test.AddOutput<float>("squeezed", {2, 4}, std::vector<float>(8, 1.0f));
|
||||
// TensorRT doesn't seem to support missing 'axes'
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(SqueezeOpTest, Squeeze_1_int32) {
|
||||
OpTester test("Squeeze");
|
||||
test.AddAttribute("axes", std::vector<int64_t>{0});
|
||||
|
|
|
|||
Loading…
Reference in a new issue