CoreML: Add GridSample ML Program support (#21431)

### Description
<!-- Describe your changes. -->
Add GridSample ML Program support

One combination of inputs has diffs between the pytorch generated unit
tests data and CoreML. Disabling until needed as investigation may take
a while.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
High priorities models
This commit is contained in:
Scott McKay 2024-07-24 11:04:48 +10:00 committed by GitHub
parent 86cedc6832
commit 1df9aa2f08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 191 additions and 50 deletions

View file

@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/shared/utils/utils.h"
namespace onnxruntime {
namespace coreml {
namespace {
std::string_view GetMode(const NodeAttrHelper& helper) {
// opset 16 used bilinear, nearest, bicubic
// opset 20+ uses linear, nearest, cubic
// bilinear is what CoreML uses, so prefer that
// bicubic/cubic isn't supported
const auto& mode = helper.Get("mode", "linear");
if (mode == "linear") {
return "bilinear";
}
return mode;
}
} // namespace
class GridSampleOpBuilder : public BaseOpBuilder {
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override;
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;
bool SupportsMLProgram() const override { return true; }
};
Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
[[maybe_unused]] const Node& node,
[[maybe_unused]] const logging::Logger& logger) const {
#if defined(COREML_ENABLE_MLPROGRAM)
using namespace CoreML::Specification::MILSpec; // NOLINT
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resample
const auto input_defs = node.InputDefs();
const auto output_defs = node.OutputDefs();
NodeAttrHelper helper(node);
std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant
std::string padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);
const std::string coordinates_mode = "normalized_minus_one_to_one";
// adjust to coreml equivalents
if (padding_mode == "zeros") {
padding_mode = "constant";
}
auto op = model_builder.CreateOperation(node, "resample");
AddOperationInput(*op, "x", input_defs[0]->Name());
AddOperationInput(*op, "coordinates", input_defs[1]->Name());
AddOperationInput(*op, "sampling_mode", model_builder.AddScalarConstant(op->type(), "sampling_mode", mode));
AddOperationInput(*op, "padding_mode", model_builder.AddScalarConstant(op->type(), "padding_mode", padding_mode));
AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f));
AddOperationInput(*op, "coordinates_mode",
model_builder.AddScalarConstant(op->type(), "coordinates_mode", coordinates_mode));
AddOperationInput(*op, "align_corners", model_builder.AddScalarConstant(op->type(), "align_corners", align_corners));
AddOperationOutput(*op, *output_defs[0]);
model_builder.AddOperation(std::move(op));
#endif
return Status::OK();
}
bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
if (!input_params.create_mlprogram) {
LOGS(logger, VERBOSE) << "GridSample is not supported.";
return false;
}
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, VERBOSE) << "GridSample: failed to get input shape";
return false;
}
const auto input_rank = input_shape.size();
if (input_rank != 4) {
LOGS(logger, VERBOSE) << "GridSample only supports 4D input. Got:" << input_rank << "D";
return false;
}
NodeAttrHelper helper(node);
std::string_view mode = GetMode(helper);
if (mode != "bilinear" && mode != "zeros") {
LOGS(logger, VERBOSE) << "GridSample does not support mode of " << mode;
return false;
}
// there is one combination of settings where the unit test fails.
// The ORT unit test values are generated by pytorch so not clear if it's an issue with CoreML.
// CoreML output is consistent for CPU and non-CPU at least.
// Disabling until there's a use-case that requires this combination.
const auto& padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);
if (mode == "bilinear" && padding_mode == "reflection" && align_corners == false) {
LOGS(logger, VERBOSE) << "GridSample does not support mode:" << mode << " padding_mode:" << padding_mode
<< " align_corners:" << align_corners
<< " currently due to output diffs that need to be investigated";
return false;
}
return true;
}
void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<GridSampleOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}
} // namespace coreml
} // namespace onnxruntime

View file

@ -130,6 +130,8 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateSplitOpBuilder("Split", op_registrations);
}
CreateGridSampleOpBuilder("GridSample", op_registrations);
return op_registrations;
}

View file

@ -28,6 +28,7 @@ void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrati
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

View file

@ -13,6 +13,7 @@ std::vector<std::unique_ptr<IExecutionProvider>> GetExecutionProviders(int opset
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
if (opset_version < 20) {
execution_providers.emplace_back(DefaultCudaExecutionProvider());
@ -20,8 +21,12 @@ std::vector<std::unique_ptr<IExecutionProvider>> GetExecutionProviders(int opset
execution_providers.push_back(DefaultCudaNHWCExecutionProvider());
#endif
}
#endif
#if defined(USE_COREML)
execution_providers.push_back(DefaultCoreMLExecutionProvider(/*use_mlprogram*/ true));
#endif
return execution_providers;
}
@ -35,7 +40,7 @@ void RunTests(T& test, std::vector<std::unique_ptr<IExecutionProvider>>&& execut
// DO NOT edit following tests. They are generated by:
// onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "nearest";
std::string padding_mode = "zeros";
@ -55,7 +60,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "nearest";
std::string padding_mode = "zeros";
@ -75,7 +80,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "nearest";
std::string padding_mode = "border";
@ -95,7 +100,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "nearest";
std::string padding_mode = "border";
@ -115,7 +120,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "nearest";
std::string padding_mode = "reflection";
@ -135,7 +140,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "nearest";
std::string padding_mode = "reflection";
@ -155,7 +160,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners)
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bilinear";
std::string padding_mode = "zeros";
@ -175,7 +180,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bilinear";
std::string padding_mode = "zeros";
@ -195,7 +200,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bilinear";
std::string padding_mode = "border";
@ -215,7 +220,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bilinear";
std::string padding_mode = "border";
@ -235,7 +240,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bilinear";
std::string padding_mode = "reflection";
@ -255,7 +260,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bilinear";
std::string padding_mode = "reflection";
@ -275,7 +280,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bicubic";
std::string padding_mode = "zeros";
@ -295,7 +300,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bicubic";
std::string padding_mode = "zeros";
@ -315,7 +320,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bicubic";
std::string padding_mode = "border";
@ -335,7 +340,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bicubic";
std::string padding_mode = "border";
@ -355,7 +360,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bicubic";
std::string padding_mode = "reflection";
@ -375,7 +380,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
OpTester test("GridSample", 16);
std::string mode = "bicubic";
std::string padding_mode = "reflection";
@ -395,7 +400,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners)
RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "zeros";
@ -415,7 +420,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "zeros";
@ -435,7 +440,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "zeros";
@ -455,7 +460,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "zeros";
@ -475,7 +480,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "border";
@ -495,7 +500,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "border";
@ -515,7 +520,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "border";
@ -535,7 +540,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "border";
@ -555,7 +560,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "reflection";
@ -575,7 +580,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "reflection";
@ -595,7 +600,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "reflection";
@ -615,7 +620,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners)
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "nearest";
std::string padding_mode = "reflection";
@ -635,7 +640,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners)
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "zeros";
@ -655,7 +660,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "zeros";
@ -675,7 +680,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "zeros";
@ -695,7 +700,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "zeros";
@ -715,7 +720,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "border";
@ -735,7 +740,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "border";
@ -755,7 +760,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "border";
@ -775,7 +780,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "border";
@ -795,7 +800,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "reflection";
@ -815,7 +820,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "reflection";
@ -835,7 +840,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "reflection";
@ -855,7 +860,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "reflection";
@ -875,7 +880,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "cubic";
std::string padding_mode = "zeros";
@ -895,7 +900,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "cubic";
std::string padding_mode = "zeros";
@ -915,7 +920,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "cubic";
std::string padding_mode = "border";
@ -935,7 +940,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "cubic";
std::string padding_mode = "border";
@ -955,7 +960,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "cubic";
std::string padding_mode = "reflection";
@ -975,7 +980,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) {
TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "cubic";
std::string padding_mode = "reflection";

View file

@ -58,7 +58,7 @@ for opset_version in [16, 20]:
onnx_align_corners = 1 if align_corners else 0
test_name = f"test_grid_sample_{opset_version}_{ndim}D_{mode}_{padding_mode}_{'align_corners' if align_corners else 'no_align_corners'}"
print(f"TEST(GridsampleTest, {test_name}) {{")
print(f"TEST(GridSampleTest, {test_name}) {{")
print(f'OpTester test("GridSample", {opset_version});')
print(f'std::string mode = "{onnx_mode}";')
print(f'std::string padding_mode = "{padding_mode}";')

View file

@ -11,6 +11,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
|ai.onnx:Gemm|Input B must be constant.|
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:GridSample|4D input.<br/>'mode' of 'linear' or 'zeros'.<br/>(mode==linear && padding_mode==reflection && align_corners==0) is not supported.|
|ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.|
|ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:Mul||