diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 47732d41cd..a44cf60f9e 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -86,7 +86,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNH class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalAveragePool); #endif - // This section includes all op kernel declarations for former experimental ops which have now been removed from onnx. // To maintain backward compatibility these are added as contrib ops. // Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops @@ -122,6 +121,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ShrunkenGather); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -151,7 +156,6 @@ Status RegisterNchwcKernels(KernelRegistry& kernel_registry) { return Status::OK(); } - #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -171,8 +175,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { } #endif - - Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing @@ -283,6 +285,13 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { #ifdef ENABLE_ATEN BuildKernelCreateInfo, #endif + +#ifdef ENABLE_TRAINING_OPS + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. + BuildKernelCreateInfo, +#endif + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.cc b/onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.cc new file mode 100644 index 0000000000..b8885a0e1f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+ +#include "contrib_ops/cpu/tensor/shrunken_gather.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + ShrunkenGather, + kMSDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", + std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ShrunkenGather); + +void ShrunkenGatherCommon::CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const { + const auto& input_shape = input_tensor->Shape(); + const auto& indices_shape = indices_tensor->Shape(); + + ORT_ENFORCE(input_shape.NumDimensions() >= 1, "ShrunkenGather only support input with rank >= 1, got ", + input_shape.NumDimensions(), "-D input"); + + ORT_ENFORCE(indices_shape.NumDimensions() == 1, "ShrunkenGather only support 1D indices, got ", + indices_shape.NumDimensions(), "-D indices"); + + const auto input_rank = input_shape.NumDimensions(); + auto axis = HandleNegativeAxis(axis_in, narrow(input_rank)); + + const int64_t N = indices_shape.Size(); + const int64_t indices_max = input_shape[axis]; + ORT_ENFORCE(indices_max >= N, "ShrunkenGather indices elem count should <= input dim on axis: ", axis, + ", got indices elem count:", N, " input dim: ", indices_max); +} + +Status ShrunkenGather::Compute(OpKernelContext* context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); + ShrunkenGatherCommon::CheckInput(p.input_tensor, p.indices_tensor, p.axis); + return Gather::Compute(context); +} + +} // namespace contrib +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.h b/onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.h new file mode 100644 index 0000000000..33e761a754 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/gather.h" + +namespace onnxruntime { +namespace contrib { + +class ShrunkenGatherCommon { + public: + void CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const; +}; + +class ShrunkenGather final : public onnxruntime::Gather, public ShrunkenGatherCommon { + public: + ShrunkenGather(const OpKernelInfo& info) : Gather(info) {} + + Status Compute(OpKernelContext* context) const override; +}; +} // namespace contrib +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 0b800c78bc..d9a77c7bab 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -140,6 +140,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather); +#endif + #if defined(USE_MPI) && defined(ORT_USE_NCCL) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); @@ -288,6 +294,12 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif +#ifdef ENABLE_TRAINING_OPS + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. + BuildKernelCreateInfo, +#endif + #if defined(USE_MPI) && defined(ORT_USE_NCCL) BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.cc b/onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.cc new file mode 100644 index 0000000000..ae7e55b231 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+ +#include "contrib_ops/cuda/tensor/shrunken_gather.h" +#include "contrib_ops/cpu/tensor/shrunken_gather.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +ONNX_OPERATOR_KERNEL_EX( + ShrunkenGather, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ShrunkenGather); + +Status ShrunkenGather::ComputeInternal(OpKernelContext* context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); + ShrunkenGatherCommon::CheckInput(p.input_tensor, p.indices_tensor, p.axis); + return onnxruntime::cuda::Gather::ComputeInternal(context); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.h b/onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.h new file mode 100644 index 0000000000..f64da9516e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+ +#pragma once + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/tensor/gather.h" +#include "contrib_ops/cpu/tensor/shrunken_gather.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// CUDA ShrunkenGather kernel: reuses onnxruntime::cuda::Gather and adds the ShrunkenGatherCommon input checks. +class ShrunkenGather final : public onnxruntime::cuda::Gather, public ShrunkenGatherCommon { + public: + ShrunkenGather(const OpKernelInfo& info) : onnxruntime::cuda::Gather(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 062aa32363..a3efd80668 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -116,6 +116,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); +#endif + #if defined(USE_MPI) && defined(ORT_USE_NCCL) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); @@ -247,6 +253,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { #ifdef ENABLE_ATEN BuildKernelCreateInfo, #endif + +#ifdef ENABLE_TRAINING_OPS + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. 
+ BuildKernelCreateInfo, +#endif + #if defined(USE_MPI) && defined(ORT_USE_NCCL) BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 728de78de3..95bb6b5937 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -424,16 +424,14 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { auto& input_ids_dims = input_ids_shape.dim(); auto model_type_attr = ctx.getAttribute("model_type"); int64_t model_type = model_type_attr ? static_cast(model_type_attr->i()) : -1; - if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) { - if (input_ids_dims.size() != 3) - { + if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) { + if (input_ids_dims.size() != 3) { fail_shape_inference("Inputs 0 shall be 3 dimensions in whisper graph"); - } + } if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value() && input_ids_dims[2].has_dim_value())) { return; } - } - else if (input_ids_dims.size() != 2) { + } else if (input_ids_dims.size() != 2) { fail_shape_inference("Inputs 0 shall be 2 dimensions", model_type); } if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) { @@ -2722,6 +2720,89 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t "Allow inputs and outputs to be any kind of tensor."); #endif +#ifdef ENABLE_TRAINING_OPS + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. + + static const char* ShrunkenGather_ver1_doc = R"DOC( + This op is a specialised case of Gather-13, adding additional constraint including: indices being 1D, +and indices count <= input element count on the specified axis. 
+ +Having this op allows runtime to do operator re-ordering to reduce compute FLOPs. +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(ShrunkenGather) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc(ShrunkenGather_ver1_doc) + .AllowUncheckedAttributes() + .Attr( + "axis", + "Which axis to gather on. Negative value means " + "counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "Tensor of rank r >= 1.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "indices", + "Tensor of int32/int64 indices, with rank = 1. All index values are expected to be within bounds [-s, s-1] " + "along axis of size s. It is an error if any of the index values are out of bounds. " + "The number of elements in indices must not exceed the size of the input tensor along axis, " + "which is the reason why this op is called ShrunkenGather.", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Tensor of rank r.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint( + "T", + OpSchema::all_tensor_types_with_bfloat(), + "Constrain input and output types to any tensor type.") + .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (!hasNInputShapes(ctx, 2)) { + return; + } + + const TensorShapeProto& data_shape = ctx.getInputType(0)->tensor_type().shape(); + const TensorShapeProto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); + int r = data_shape.dim_size(); + if (r < 1) { + fail_shape_inference("data tensor must have rank >= 1"); + } + int q = indices_shape.dim_size(); + int axis = static_cast(getAttribute(ctx, "axis", 0)); + if (axis < -r || axis >= r) { + fail_shape_inference("axis must 
be in [-r, r-1]"); + } + if (axis < 0) { + axis += r; + } + + int out_rank = q + r - 1; + auto final_output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + int i = 0; + for (; i < axis; ++i) { + *final_output_shape->add_dim() = data_shape.dim(i); + } + + for (; i < axis + q; ++i) { + *final_output_shape->add_dim() = indices_shape.dim(i - axis); + } + + for (; i < out_rank; ++i) { + *final_output_shape->add_dim() = data_shape.dim(i - q + 1); + } + }); + +#endif + #ifndef _OPSCHEMA_LIB_ // Register the NCHWc schemas if supported by the platform. if (MlasNchwcGetBlockSize() > 1) { diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index c0a75fc50b..d3d252cf70 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -45,6 +45,10 @@ #include "orttraining/training_ops/cpu/tensor/split.h" #include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h" #include "orttraining/training_ops/cpu/optimizer/sgd/sgdbase.h" + +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+#include "contrib_ops/cpu/tensor/shrunken_gather.h" #endif #ifdef ENABLE_TRAINING @@ -275,6 +279,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU { return p->SGDOptimizerV2Base::PrepareForCompute(ctx, reinterpret_cast(prepare)); } + void contrib__ShrunkenGatherCommon__CheckInput(const contrib::ShrunkenGatherCommon* p, const Tensor* input_tensor, + const Tensor* indices_tensor, int64_t axis_in) const override { + return p->ShrunkenGatherCommon::CheckInput(input_tensor, indices_tensor, axis_in); + } + #endif #ifdef ENABLE_TRAINING diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 1ecbca7a55..1fafc646c8 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -194,6 +194,12 @@ struct ProviderHostCPU { virtual Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, contrib__AdamWOptimizerBase__Prepare& prepare) = 0; // From cpu/optimizer/sgdbase.h virtual Status contrib__SGDOptimizerV2Base__PrepareForCompute(const contrib::SGDOptimizerV2Base* p, OpKernelContext* ctx, contrib__SGDOptimizerV2Base__Prepare& prepare) = 0; + + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. 
+ virtual void contrib__ShrunkenGatherCommon__CheckInput(const contrib::ShrunkenGatherCommon* p, + const Tensor* input_tensor, const Tensor* indices_tensor, + int64_t axis_in) const = 0; #endif #ifdef ENABLE_TRAINING diff --git a/onnxruntime/core/providers/cpu/tensor/gather.h b/onnxruntime/core/providers/cpu/tensor/gather.h index 3508ff981d..0b99219f8a 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather.h +++ b/onnxruntime/core/providers/cpu/tensor/gather.h @@ -10,7 +10,7 @@ namespace onnxruntime { -class Gather final : public OpKernel, public GatherBase { +class Gather : public OpKernel, public GatherBase { public: Gather(const OpKernelInfo& info) : OpKernel(info), GatherBase(info) {} diff --git a/onnxruntime/core/providers/cuda/tensor/gather.h b/onnxruntime/core/providers/cuda/tensor/gather.h index c4a2a71f5a..e5bc6cecb0 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather.h +++ b/onnxruntime/core/providers/cuda/tensor/gather.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace cuda { -class Gather final : public CudaKernel, public GatherBase { +class Gather : public CudaKernel, public GatherBase { public: Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {} Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index c010fa0a83..01c1779af0 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -175,6 +175,7 @@ class PassThrough; class YieldOp; class AdamWOptimizerBase; class SGDOptimizerV2Base; +class ShrunkenGatherCommon; } // namespace contrib class UnsqueezeBase; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 817ae047da..97e5db9a03 100644 --- 
a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -43,6 +43,10 @@ #include "orttraining/training_ops/cpu/controlflow/group.h" #include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h" #include "orttraining/training_ops/cpu/optimizer/sgd/sgdbase.h" + +// Should remove the include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. +#include "contrib_ops/cpu/tensor/shrunken_gather.h" #endif #ifdef ENABLE_TRAINING @@ -647,6 +651,9 @@ Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx, AdamWOptimize Status SGDOptimizerV2Base::PrepareForCompute(OpKernelContext* ctx, SGDOptimizerV2Base::Prepare& prepare) const { return g_host_cpu.contrib__SGDOptimizerV2Base__PrepareForCompute(this, ctx, reinterpret_cast(prepare)); } +void ShrunkenGatherCommon::CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const { + return g_host_cpu.contrib__ShrunkenGatherCommon__CheckInput(this, input_tensor, indices_tensor, axis_in); +} } // namespace contrib #endif diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc index d49efa57eb..0f199d3fc9 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc @@ -423,5 +423,105 @@ TEST(GatherOpTest, Gather_axis1_neg_indices2d_int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Assertion `regionRanges != nullptr' failed } +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
+ +TEST(ShrunkenGatherOpTest, ShrunkenGather_PositiveAxis) { + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#endif +#ifdef USE_ROCM + execution_providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); + test.AddAttribute("axis", 0LL); + test.AddInput("data", {3, 4}, + {0.0f, 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f}); + test.AddInput("indices", {2}, {1LL, 0LL}); + test.AddOutput("output", {2, 4}, {4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 1.0f, 2.0f, 3.0f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, + nullptr, + &execution_providers); +} + +TEST(ShrunkenGatherOpTest, ShrunkenGather_NegativeAxis) { + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#endif +#ifdef USE_ROCM + execution_providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); + test.AddAttribute("axis", -1LL); + test.AddInput("data", {3, 4}, + {0.0f, 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f}); + test.AddInput("indices", {2}, {0LL, 3LL}); + test.AddOutput("output", {3, 2}, {0.0f, 3.0f, 4.0f, 7.0f, 8.0f, 11.0f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, + nullptr, + &execution_providers); +} + +TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidIndicesRank) { + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#endif +#ifdef USE_ROCM + execution_providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); + 
test.AddAttribute("axis", 0LL); + test.AddInput("data", {3, 4}, + {0.0f, 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f}); + test.AddInput("indices", {1, 2}, {0LL, 1LL}); // invalid rank for ShrunkenGather + test.AddOutput("output", {1, 2, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "ShrunkenGather only support 1D indices, got 2-D indices", {}, + nullptr, + &execution_providers); +} + +TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidInputRank) { + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#endif +#ifdef USE_ROCM + execution_providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain); + test.AddAttribute("axis", 0LL); + test.AddInput("data", {}, // invalid rank for ShrunkenGather + {1.f}); + test.AddInput("indices", {1}, {0LL}); + test.AddOutput("output", {}, {0.f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "data tensor must have rank >= 1", {}, + nullptr, + &execution_providers); +} + +#endif + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index c46f03f37e..512dfb2035 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -752,6 +752,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) { SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) { + return std::vector{ + NodeDef("Shape", + {I(0)}, + {IA("I0_shape")}), + NodeDef(OpDef{"GatherGrad", kMSDomain, 1}, + {IA("I0_shape"), I(1), GO(0)}, + {GI(0)}, + SrcNodeAttributes())}; +} + IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) { return std::vector{ NodeDef("Shape", diff --git 
a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 15aee6cb12..156ed7f8c8 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -39,6 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient) DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient) DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient) DECLARE_GRADIENT_BUILDER(GetGatherGradient) +DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient) DECLARE_GRADIENT_BUILDER(GetConvGradient) DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient) DECLARE_GRADIENT_BUILDER(GetSqueezeGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 028bdc9c47..51a7c75ef6 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -70,6 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient); REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient); REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient); + REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient); REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient); REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient); REGISTER_GRADIENT_BUILDER("Unsqueeze", GetUnsqueezeGradient);