Introduce shrunken gather operator (#15396)

### Introduce shrunken gather operator

The existing Gather operator schema does not guarantee that the output element count
is smaller than the input element count.
In fact, the output element count can be greater than, equal to, or less than the input
element count.

In some cases we know for sure that the output element count MUST be <= the input
element count; we will upstream those Gather operators to reduce compute
FLOPs.

So this PR introduces a ShrunkenGather operator, which explicitly guarantees
that the output element count is no larger than the input element count. The
operator adds additional restrictions on its inputs, but still reuses the existing
Gather implementations, plus an input check at runtime.

This is a prerequisite for the subsequent optimizations (Draft PR:
https://github.com/microsoft/onnxruntime/pull/15401) that we will do for
label sparsity and embedding sparsity.
This commit is contained in:
pengwa 2023-04-07 15:12:58 +08:00 committed by GitHub
parent d31dd5935a
commit 16f5909f2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 415 additions and 12 deletions

View file

@ -86,7 +86,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNH
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalAveragePool);
#endif
// This section includes all op kernel declarations for former experimental ops which have now been removed from onnx.
// To maintain backward compatibility these are added as contrib ops.
// Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops
@ -122,6 +121,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen);
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ShrunkenGather);
#endif
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
@ -151,7 +156,6 @@ Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
return Status::OK();
}
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -171,8 +175,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
}
#endif
Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
@ -283,6 +285,13 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
#endif
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
namespace onnxruntime {
namespace contrib {
// Registers the CPU kernel for ShrunkenGather (kMSDomain, version 1).
// "T" is constrained to all tensor types; "Tind" (the indices type) is constrained to
// int32 and int64 tensors. The kernel class is ShrunkenGather.
ONNX_OPERATOR_KERNEL_EX(
ShrunkenGather,
kMSDomain,
1,
kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind",
std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
ShrunkenGather);
void ShrunkenGatherCommon::CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const {
const auto& input_shape = input_tensor->Shape();
const auto& indices_shape = indices_tensor->Shape();
ORT_ENFORCE(input_shape.NumDimensions() >= 1, "ShrunkenGather only support input with rank >= 1, got ",
input_shape.NumDimensions(), "-D input");
ORT_ENFORCE(indices_shape.NumDimensions() == 1, "ShrunkenGather only support 1D indices, got ",
indices_shape.NumDimensions(), "-D indices");
const auto input_rank = input_shape.NumDimensions();
auto axis = HandleNegativeAxis(axis_in, narrow<int64_t>(input_rank));
const int64_t N = indices_shape.Size();
const int64_t indices_max = input_shape[axis];
ORT_ENFORCE(indices_max >= N, "ShrunkenGather indices elem count should <= input dim on axis: ", axis,
", got indices elem count:", N, " input dim: ", indices_max);
}
// Runs the ShrunkenGather-specific input validation, then delegates the actual gather
// to the base Gather kernel. Returns the status of PrepareForCompute if it fails,
// otherwise the status returned by Gather::Compute.
Status ShrunkenGather::Compute(OpKernelContext* context) const {
  // Prepare is only used here to obtain the input/indices tensors and the axis for validation.
  Prepare prepared;
  ORT_RETURN_IF_ERROR(PrepareForCompute(context, prepared));

  ShrunkenGatherCommon::CheckInput(prepared.input_tensor, prepared.indices_tensor, prepared.axis);

  return Gather::Compute(context);
}
} // namespace contrib
} // namespace onnxruntime
#endif

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/tensor/gather.h"
namespace onnxruntime {
namespace contrib {
// Mix-in that carries the input validation used by the ShrunkenGather kernels.
class ShrunkenGatherCommon {
public:
// Validates the data/indices tensors and the gather axis against ShrunkenGather's constraints.
// axis_in is the axis as supplied by the caller.
void CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const;
};
// CPU kernel for ShrunkenGather: derives from the regular Gather kernel and mixes in
// ShrunkenGatherCommon for the additional input validation.
class ShrunkenGather final : public onnxruntime::Gather, public ShrunkenGatherCommon {
public:
// Forwards the kernel info to the Gather base class.
ShrunkenGather(const OpKernelInfo& info) : Gather(info) {}
// Overrides the base kernel's Compute.
Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime
#endif

View file

@ -140,6 +140,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather);
#endif
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
@ -288,6 +294,12 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
#endif
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,

View file

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
#include "contrib_ops/cuda/tensor/shrunken_gather.h"
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
// Registers the CUDA kernel for ShrunkenGather (kMSDomain, version 1).
// "T" is constrained to all tensor types; "Tind" (the indices type) is constrained to
// int32 and int64 tensors. The kernel class is ShrunkenGather.
ONNX_OPERATOR_KERNEL_EX(
ShrunkenGather,
kMSDomain,
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
ShrunkenGather);
// Runs the ShrunkenGather-specific input validation, then delegates the actual gather
// to the CUDA Gather kernel. Returns the status of PrepareForCompute if it fails,
// otherwise the status returned by the base ComputeInternal.
Status ShrunkenGather::ComputeInternal(OpKernelContext* context) const {
  // Prepare is only used here to obtain the input/indices tensors and the axis for validation.
  Prepare prepared;
  ORT_RETURN_IF_ERROR(PrepareForCompute(context, prepared));

  ShrunkenGatherCommon::CheckInput(prepared.input_tensor, prepared.indices_tensor, prepared.axis);

  return onnxruntime::cuda::Gather::ComputeInternal(context);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.

// Guard against multiple inclusion (the class below would otherwise be redefined).
#pragma once

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/tensor/gather.h"
#include "contrib_ops/cpu/tensor/shrunken_gather.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

// CUDA kernel for ShrunkenGather: derives from the CUDA Gather kernel and mixes in
// ShrunkenGatherCommon for the additional input validation.
class ShrunkenGather final : public onnxruntime::cuda::Gather, public ShrunkenGatherCommon {
 public:
  // Forwards the kernel info to the CUDA Gather base class.
  ShrunkenGather(const OpKernelInfo& info) : onnxruntime::cuda::Gather(info) {}

  // Overrides the base kernel's ComputeInternal.
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime

#endif

View file

@ -116,6 +116,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen);
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather);
#endif
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
@ -247,6 +253,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
#endif
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,

View file

@ -424,16 +424,14 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
auto& input_ids_dims = input_ids_shape.dim();
auto model_type_attr = ctx.getAttribute("model_type");
int64_t model_type = model_type_attr ? static_cast<int64_t>(model_type_attr->i()) : -1;
if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) {
if (input_ids_dims.size() != 3)
{
if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) {
if (input_ids_dims.size() != 3) {
fail_shape_inference("Inputs 0 shall be 3 dimensions in whisper graph");
}
}
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value() && input_ids_dims[2].has_dim_value())) {
return;
}
}
else if (input_ids_dims.size() != 2) {
} else if (input_ids_dims.size() != 2) {
fail_shape_inference("Inputs 0 shall be 2 dimensions", model_type);
}
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) {
@ -2722,6 +2720,89 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t
"Allow inputs and outputs to be any kind of tensor.");
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
// Operator documentation for ShrunkenGather (passed to the schema through SetDoc).
// The count constraint is "<=", matching the runtime check performed by the kernels.
static const char* ShrunkenGather_ver1_doc = R"DOC(
This op is a specialised case of Gather-13, adding additional constraints: indices must be 1D,
and the indices count must be <= the input dimension size on the specified axis.
Having this op allows runtime to do operator re-ordering to reduce compute FLOPs.
)DOC";
// Schema for ShrunkenGather (kMSDomain, version 1, experimental).
// Inputs: data ("T") and indices ("Tind": int32/int64); output has the element type of data.
// The shape inference below computes the output shape as
//   data.shape[:axis] + indices.shape + data.shape[axis + 1:].
ONNX_CONTRIB_OPERATOR_SCHEMA(ShrunkenGather)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
    .SetDoc(ShrunkenGather_ver1_doc)
    .AllowUncheckedAttributes()
    .Attr(
        "axis",
        "Which axis to gather on. Negative value means "
        "counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).",
        AttributeProto::INT,
        static_cast<int64_t>(0))
    .Input(0, "data", "Tensor of rank r >= 1.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
    .Input(
        1,
        "indices",
        "Tensor of int32/int64 indices, with rank = 1. All index values are expected to be within bounds [-s, s-1] "
        "along axis of size s. It is an error if any of the index values are out of bounds. "
        "The number of elements in indices must be less than or equal to the size of the input tensor along axis, "
        "which is the reason why this op is called ShrunkenGather.",
        "Tind",
        OpSchema::Single,
        true,
        1,
        OpSchema::NonDifferentiable)
    .Output(0, "output", "Tensor of rank r.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
    .TypeConstraint(
        "T",
        OpSchema::all_tensor_types_with_bfloat(),
        "Constrain input and output types to any tensor type.")
    .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types")
    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
      // The output element type always follows the data input.
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      // Shape inference needs the shapes of both inputs.
      if (!hasNInputShapes(ctx, 2)) {
        return;
      }
      const TensorShapeProto& data_shape = ctx.getInputType(0)->tensor_type().shape();
      const TensorShapeProto& indices_shape = ctx.getInputType(1)->tensor_type().shape();
      int r = data_shape.dim_size();
      if (r < 1) {
        fail_shape_inference("data tensor must have rank >= 1");
      }
      // NOTE: the rank of indices (q) is not validated here; the 1-D requirement is
      // enforced at runtime by ShrunkenGatherCommon::CheckInput.
      int q = indices_shape.dim_size();
      int axis = static_cast<int>(getAttribute(ctx, "axis", 0));
      if (axis < -r || axis >= r) {
        fail_shape_inference("axis must be in [-r, r-1]");
      }
      // Normalize a negative axis into [0, r).
      if (axis < 0) {
        axis += r;
      }
      int out_rank = q + r - 1;
      auto final_output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
      int i = 0;
      // Leading dims: data dims before the gather axis.
      for (; i < axis; ++i) {
        *final_output_shape->add_dim() = data_shape.dim(i);
      }
      // Middle dims: the indices dims replace the gather axis.
      for (; i < axis + q; ++i) {
        *final_output_shape->add_dim() = indices_shape.dim(i - axis);
      }
      // Trailing dims: data dims after the gather axis.
      for (; i < out_rank; ++i) {
        *final_output_shape->add_dim() = data_shape.dim(i - q + 1);
      }
    });
#endif
#ifndef _OPSCHEMA_LIB_
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {

View file

@ -45,6 +45,10 @@
#include "orttraining/training_ops/cpu/tensor/split.h"
#include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h"
#include "orttraining/training_ops/cpu/optimizer/sgd/sgdbase.h"
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
#endif
#ifdef ENABLE_TRAINING
@ -275,6 +279,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
return p->SGDOptimizerV2Base::PrepareForCompute(ctx,
reinterpret_cast<contrib::SGDOptimizerV2Base::Prepare&>(prepare));
}
// Host-side implementation of the bridge entry: forwards to ShrunkenGatherCommon::CheckInput on p.
void contrib__ShrunkenGatherCommon__CheckInput(const contrib::ShrunkenGatherCommon* p, const Tensor* input_tensor,
                                               const Tensor* indices_tensor, int64_t axis_in) const override {
  p->ShrunkenGatherCommon::CheckInput(input_tensor, indices_tensor, axis_in);
}
#endif
#ifdef ENABLE_TRAINING

View file

@ -194,6 +194,12 @@ struct ProviderHostCPU {
virtual Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, contrib__AdamWOptimizerBase__Prepare& prepare) = 0;
// From cpu/optimizer/sgdbase.h
virtual Status contrib__SGDOptimizerV2Base__PrepareForCompute(const contrib::SGDOptimizerV2Base* p, OpKernelContext* ctx, contrib__SGDOptimizerV2Base__Prepare& prepare) = 0;
// From contrib_ops/cpu/tensor/shrunken_gather.h
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
// Bridge entry for ShrunkenGatherCommon::CheckInput; p is the object the check is invoked on.
virtual void contrib__ShrunkenGatherCommon__CheckInput(const contrib::ShrunkenGatherCommon* p,
const Tensor* input_tensor, const Tensor* indices_tensor,
int64_t axis_in) const = 0;
#endif
#ifdef ENABLE_TRAINING

View file

@ -10,7 +10,7 @@
namespace onnxruntime {
class Gather final : public OpKernel, public GatherBase {
class Gather : public OpKernel, public GatherBase {
public:
Gather(const OpKernelInfo& info) : OpKernel(info), GatherBase(info) {}

View file

@ -8,7 +8,7 @@
namespace onnxruntime {
namespace cuda {
class Gather final : public CudaKernel, public GatherBase {
class Gather : public CudaKernel, public GatherBase {
public:
Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;

View file

@ -175,6 +175,7 @@ class PassThrough;
class YieldOp;
class AdamWOptimizerBase;
class SGDOptimizerV2Base;
class ShrunkenGatherCommon;
} // namespace contrib
class UnsqueezeBase;

View file

@ -43,6 +43,10 @@
#include "orttraining/training_ops/cpu/controlflow/group.h"
#include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h"
#include "orttraining/training_ops/cpu/optimizer/sgd/sgdbase.h"
// Should remove the include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
#endif
#ifdef ENABLE_TRAINING
@ -647,6 +651,9 @@ Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx, AdamWOptimize
Status SGDOptimizerV2Base::PrepareForCompute(OpKernelContext* ctx, SGDOptimizerV2Base::Prepare& prepare) const {
return g_host_cpu.contrib__SGDOptimizerV2Base__PrepareForCompute(this, ctx, reinterpret_cast<contrib__SGDOptimizerV2Base__Prepare&>(prepare));
}
// Provider-side definition of ShrunkenGatherCommon::CheckInput: routes the call through
// g_host_cpu, passing this object along with the unchanged arguments.
void ShrunkenGatherCommon::CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const {
  g_host_cpu.contrib__ShrunkenGatherCommon__CheckInput(this, input_tensor, indices_tensor, axis_in);
}
} // namespace contrib
#endif

View file

@ -423,5 +423,105 @@ TEST(GatherOpTest, Gather_axis1_neg_indices2d_int8) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Assertion `regionRanges != nullptr' failed
}
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purpose.
// Gathers rows 1 and 0 (axis = 0) of a 3x4 tensor and expects a 2x4 result.
TEST(ShrunkenGatherOpTest, ShrunkenGather_PositiveAxis) {
  OpTester tester("ShrunkenGather", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("axis", 0LL);
  tester.AddInput<float>("data", {3, 4},
                         {0.0f, 1.0f, 2.0f, 3.0f,
                          4.0f, 5.0f, 6.0f, 7.0f,
                          8.0f, 9.0f, 10.0f, 11.0f});
  tester.AddInput<int32_t>("indices", {2}, {1LL, 0LL});
  tester.AddOutput<float>("output", {2, 4}, {4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 1.0f, 2.0f, 3.0f});

  // CPU is always exercised; CUDA / ROCm are added only when compiled in.
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
  providers.emplace_back(DefaultCudaExecutionProvider());
#endif
#ifdef USE_ROCM
  providers.emplace_back(DefaultRocmExecutionProvider());
#endif

  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {},
             nullptr,
             &providers);
}
// Gathers columns 0 and 3 of a 3x4 tensor using axis = -1 and expects a 3x2 result.
TEST(ShrunkenGatherOpTest, ShrunkenGather_NegativeAxis) {
  OpTester tester("ShrunkenGather", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("axis", -1LL);
  tester.AddInput<float>("data", {3, 4},
                         {0.0f, 1.0f, 2.0f, 3.0f,
                          4.0f, 5.0f, 6.0f, 7.0f,
                          8.0f, 9.0f, 10.0f, 11.0f});
  tester.AddInput<int32_t>("indices", {2}, {0LL, 3LL});
  tester.AddOutput<float>("output", {3, 2}, {0.0f, 3.0f, 4.0f, 7.0f, 8.0f, 11.0f});

  // CPU is always exercised; CUDA / ROCm are added only when compiled in.
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
  providers.emplace_back(DefaultCudaExecutionProvider());
#endif
#ifdef USE_ROCM
  providers.emplace_back(DefaultRocmExecutionProvider());
#endif

  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {},
             nullptr,
             &providers);
}
// Feeds 2-D indices and expects the run to fail with the "only support 1D indices" message.
TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidIndicesRank) {
  OpTester tester("ShrunkenGather", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("axis", 0LL);
  tester.AddInput<float>("data", {3, 4},
                         {0.0f, 1.0f, 2.0f, 3.0f,
                          4.0f, 5.0f, 6.0f, 7.0f,
                          8.0f, 9.0f, 10.0f, 11.0f});
  tester.AddInput<int32_t>("indices", {1, 2}, {0LL, 1LL});  // invalid rank for ShrunkenGather
  tester.AddOutput<float>("output", {1, 2, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});

  // CPU is always exercised; CUDA / ROCm are added only when compiled in.
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
  providers.emplace_back(DefaultCudaExecutionProvider());
#endif
#ifdef USE_ROCM
  providers.emplace_back(DefaultRocmExecutionProvider());
#endif

  tester.Run(OpTester::ExpectResult::kExpectFailure, "ShrunkenGather only support 1D indices, got 2-D indices", {},
             nullptr,
             &providers);
}
// Feeds a rank-0 data tensor and expects the run to fail with the "rank >= 1" message.
TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidInputRank) {
  OpTester tester("ShrunkenGather", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("axis", 0LL);
  tester.AddInput<float>("data", {},  // invalid rank for ShrunkenGather
                         {1.f});
  tester.AddInput<int64_t>("indices", {1}, {0LL});
  tester.AddOutput<float>("output", {}, {0.f});

  // CPU is always exercised; CUDA / ROCm are added only when compiled in.
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
  providers.emplace_back(DefaultCudaExecutionProvider());
#endif
#ifdef USE_ROCM
  providers.emplace_back(DefaultRocmExecutionProvider());
#endif

  tester.Run(OpTester::ExpectResult::kExpectFailure, "data tensor must have rank >= 1", {},
             nullptr,
             &providers);
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -752,6 +752,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
SrcNodeAttributes())};
}
// Gradient builder for ShrunkenGather. Emits two nodes:
//   1. Shape: takes I(0) and produces the intermediate "I0_shape".
//   2. GatherGrad (kMSDomain, version 1): takes "I0_shape", I(1) and GO(0), produces GI(0),
//      and is given SrcNodeAttributes().
// NOTE(review): I/GO/GI presumably denote the forward inputs, the output gradient and the
// input gradient respectively - confirm against the gradient builder base helpers.
IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
return std::vector<NodeDef>{
NodeDef("Shape",
{I(0)},
{IA("I0_shape")}),
NodeDef(OpDef{"GatherGrad", kMSDomain, 1},
{IA("I0_shape"), I(1), GO(0)},
{GI(0)},
SrcNodeAttributes())};
}
IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) {
return std::vector<NodeDef>{
NodeDef("Shape",

View file

@ -39,6 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient)
DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
DECLARE_GRADIENT_BUILDER(GetSqueezeGradient)

View file

@ -70,6 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
REGISTER_GRADIENT_BUILDER("Unsqueeze", GetUnsqueezeGradient);