mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Introduce shrunken gather operator (#15396)
### Introduce shrunken gather operator Exist Gather operator schema won't guarantee output element count will be smaller than input element count. Actually, it is possible output element count >, =, or < input element count. For some cases we know for sure output element count MUST be <= input element count, we will upstream those Gather operators to reduce compute flops. So this PR introduces an ShrunkenGather which explicitly guarantee output count will be smaller than input count. The operator add additional restriction on inputs, but still re-use existing Gather's implementations plus input check during runtime. This is a requirement for subsequent optimization (Draft PR: https://github.com/microsoft/onnxruntime/pull/15401) we will do for label sparsity and embedding sparsity.
This commit is contained in:
parent
d31dd5935a
commit
16f5909f2d
18 changed files with 415 additions and 12 deletions
|
|
@ -86,7 +86,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNH
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalAveragePool);
|
||||
#endif
|
||||
|
||||
|
||||
// This section includes all op kernel declarations for former experimental ops which have now been removed from onnx.
|
||||
// To maintain backward compatibility these are added as contrib ops.
|
||||
// Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops
|
||||
|
|
@ -122,6 +121,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ShrunkenGather);
|
||||
#endif
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
KernelCreateInfo info;
|
||||
|
|
@ -151,7 +156,6 @@ Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
|
||||
Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -171,8 +175,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
|
|||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
|
|
@ -283,6 +285,13 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
#ifdef ENABLE_ATEN
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
55
onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.cc
Normal file
55
onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.cc
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
|
||||
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
ShrunkenGather,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
ShrunkenGather);
|
||||
|
||||
void ShrunkenGatherCommon::CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const {
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
const auto& indices_shape = indices_tensor->Shape();
|
||||
|
||||
ORT_ENFORCE(input_shape.NumDimensions() >= 1, "ShrunkenGather only support input with rank >= 1, got ",
|
||||
input_shape.NumDimensions(), "-D input");
|
||||
|
||||
ORT_ENFORCE(indices_shape.NumDimensions() == 1, "ShrunkenGather only support 1D indices, got ",
|
||||
indices_shape.NumDimensions(), "-D indices");
|
||||
|
||||
const auto input_rank = input_shape.NumDimensions();
|
||||
auto axis = HandleNegativeAxis(axis_in, narrow<int64_t>(input_rank));
|
||||
|
||||
const int64_t N = indices_shape.Size();
|
||||
const int64_t indices_max = input_shape[axis];
|
||||
ORT_ENFORCE(indices_max >= N, "ShrunkenGather indices elem count should <= input dim on axis: ", axis,
|
||||
", got indices elem count:", N, " input dim: ", indices_max);
|
||||
}
|
||||
|
||||
Status ShrunkenGather::Compute(OpKernelContext* context) const {
|
||||
Prepare p;
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
|
||||
ShrunkenGatherCommon::CheckInput(p.input_tensor, p.indices_tensor, p.axis);
|
||||
return Gather::Compute(context);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
31
onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.h
Normal file
31
onnxruntime/contrib_ops/cpu/tensor/shrunken_gather.h
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/tensor/gather.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class ShrunkenGatherCommon {
|
||||
public:
|
||||
void CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const;
|
||||
};
|
||||
|
||||
class ShrunkenGather final : public onnxruntime::Gather, public ShrunkenGatherCommon {
|
||||
public:
|
||||
ShrunkenGather(const OpKernelInfo& info) : Gather(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
|
|
@ -140,6 +140,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather);
|
||||
#endif
|
||||
|
||||
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
|
||||
|
|
@ -288,6 +294,12 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
|
||||
#endif
|
||||
|
||||
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
|
||||
|
|
|
|||
40
onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.cc
Normal file
40
onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.cc
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
|
||||
#include "contrib_ops/cuda/tensor/shrunken_gather.h"
|
||||
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
ShrunkenGather,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
ShrunkenGather);
|
||||
|
||||
Status ShrunkenGather::ComputeInternal(OpKernelContext* context) const {
|
||||
Prepare p;
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
|
||||
ShrunkenGatherCommon::CheckInput(p.input_tensor, p.indices_tensor, p.axis);
|
||||
return onnxruntime::cuda::Gather::ComputeInternal(context);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
26
onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.h
Normal file
26
onnxruntime/contrib_ops/cuda/tensor/shrunken_gather.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/cuda/tensor/gather.h"
|
||||
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
class ShrunkenGather final : public onnxruntime::cuda::Gather, public ShrunkenGatherCommon {
|
||||
public:
|
||||
ShrunkenGather(const OpKernelInfo& info) : onnxruntime::cuda::Gather(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
|
|
@ -116,6 +116,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather);
|
||||
#endif
|
||||
|
||||
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
|
||||
|
|
@ -247,6 +253,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
#ifdef ENABLE_ATEN
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
|
||||
#endif
|
||||
|
||||
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
|
||||
|
|
|
|||
|
|
@ -424,16 +424,14 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
|
|||
auto& input_ids_dims = input_ids_shape.dim();
|
||||
auto model_type_attr = ctx.getAttribute("model_type");
|
||||
int64_t model_type = model_type_attr ? static_cast<int64_t>(model_type_attr->i()) : -1;
|
||||
if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) {
|
||||
if (input_ids_dims.size() != 3)
|
||||
{
|
||||
if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) {
|
||||
if (input_ids_dims.size() != 3) {
|
||||
fail_shape_inference("Inputs 0 shall be 3 dimensions in whisper graph");
|
||||
}
|
||||
}
|
||||
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value() && input_ids_dims[2].has_dim_value())) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
else if (input_ids_dims.size() != 2) {
|
||||
} else if (input_ids_dims.size() != 2) {
|
||||
fail_shape_inference("Inputs 0 shall be 2 dimensions", model_type);
|
||||
}
|
||||
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) {
|
||||
|
|
@ -2722,6 +2720,89 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t
|
|||
"Allow inputs and outputs to be any kind of tensor.");
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
|
||||
static const char* ShrunkenGather_ver1_doc = R"DOC(
|
||||
This op is a specialised case of Gather-13, adding additional constraint including: indices being 1D,
|
||||
and indices count < input element count on the specified axis.
|
||||
|
||||
Having this op allows runtime to do operator re-ordering to reduce compute FLOPs.
|
||||
)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ShrunkenGather)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.SetDoc(ShrunkenGather_ver1_doc)
|
||||
.AllowUncheckedAttributes()
|
||||
.Attr(
|
||||
"axis",
|
||||
"Which axis to gather on. Negative value means "
|
||||
"counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Input(0, "data", "Tensor of rank r >= 1.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
|
||||
.Input(
|
||||
1,
|
||||
"indices",
|
||||
"Tensor of int64 indices, with rank = 1. All index values are expected to be within bounds [-s, s-1] "
|
||||
"along axis of size s. It is an error if any of the index values are out of bounds."
|
||||
"The number of elements in indices must be less than the number of elements in the input tensor,"
|
||||
"which is the reason why this op is called ShrunkenGather.",
|
||||
"Tind",
|
||||
OpSchema::Single,
|
||||
true,
|
||||
1,
|
||||
OpSchema::NonDifferentiable)
|
||||
.Output(0, "output", "Tensor of rank r.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types_with_bfloat(),
|
||||
"Constrain input and output types to any tensor type.")
|
||||
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types")
|
||||
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
if (!hasNInputShapes(ctx, 2)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const TensorShapeProto& data_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
const TensorShapeProto& indices_shape = ctx.getInputType(1)->tensor_type().shape();
|
||||
int r = data_shape.dim_size();
|
||||
if (r < 1) {
|
||||
fail_shape_inference("data tensor must have rank >= 1");
|
||||
}
|
||||
int q = indices_shape.dim_size();
|
||||
int axis = static_cast<int>(getAttribute(ctx, "axis", 0));
|
||||
if (axis < -r || axis >= r) {
|
||||
fail_shape_inference("axis must be in [-r, r-1]");
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += r;
|
||||
}
|
||||
|
||||
int out_rank = q + r - 1;
|
||||
auto final_output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
|
||||
int i = 0;
|
||||
for (; i < axis; ++i) {
|
||||
*final_output_shape->add_dim() = data_shape.dim(i);
|
||||
}
|
||||
|
||||
for (; i < axis + q; ++i) {
|
||||
*final_output_shape->add_dim() = indices_shape.dim(i - axis);
|
||||
}
|
||||
|
||||
for (; i < out_rank; ++i) {
|
||||
*final_output_shape->add_dim() = data_shape.dim(i - q + 1);
|
||||
}
|
||||
});
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef _OPSCHEMA_LIB_
|
||||
// Register the NCHWc schemas if supported by the platform.
|
||||
if (MlasNchwcGetBlockSize() > 1) {
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@
|
|||
#include "orttraining/training_ops/cpu/tensor/split.h"
|
||||
#include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h"
|
||||
#include "orttraining/training_ops/cpu/optimizer/sgd/sgdbase.h"
|
||||
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
@ -275,6 +279,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
|
|||
return p->SGDOptimizerV2Base::PrepareForCompute(ctx,
|
||||
reinterpret_cast<contrib::SGDOptimizerV2Base::Prepare&>(prepare));
|
||||
}
|
||||
void contrib__ShrunkenGatherCommon__CheckInput(const contrib::ShrunkenGatherCommon* p, const Tensor* input_tensor,
|
||||
const Tensor* indices_tensor, int64_t axis_in) const override {
|
||||
return p->ShrunkenGatherCommon::CheckInput(input_tensor, indices_tensor, axis_in);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
|
|||
|
|
@ -194,6 +194,12 @@ struct ProviderHostCPU {
|
|||
virtual Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, contrib__AdamWOptimizerBase__Prepare& prepare) = 0;
|
||||
// From cpu/optimizer/sgdbase.h
|
||||
virtual Status contrib__SGDOptimizerV2Base__PrepareForCompute(const contrib::SGDOptimizerV2Base* p, OpKernelContext* ctx, contrib__SGDOptimizerV2Base__Prepare& prepare) = 0;
|
||||
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
virtual void contrib__ShrunkenGatherCommon__CheckInput(const contrib::ShrunkenGatherCommon* p,
|
||||
const Tensor* input_tensor, const Tensor* indices_tensor,
|
||||
int64_t axis_in) const = 0;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
class Gather final : public OpKernel, public GatherBase {
|
||||
class Gather : public OpKernel, public GatherBase {
|
||||
public:
|
||||
Gather(const OpKernelInfo& info) : OpKernel(info), GatherBase(info) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class Gather final : public CudaKernel, public GatherBase {
|
||||
class Gather : public CudaKernel, public GatherBase {
|
||||
public:
|
||||
Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
|
|||
|
|
@ -175,6 +175,7 @@ class PassThrough;
|
|||
class YieldOp;
|
||||
class AdamWOptimizerBase;
|
||||
class SGDOptimizerV2Base;
|
||||
class ShrunkenGatherCommon;
|
||||
} // namespace contrib
|
||||
|
||||
class UnsqueezeBase;
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@
|
|||
#include "orttraining/training_ops/cpu/controlflow/group.h"
|
||||
#include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h"
|
||||
#include "orttraining/training_ops/cpu/optimizer/sgd/sgdbase.h"
|
||||
|
||||
// Should remove the include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
#include "contrib_ops/cpu/tensor/shrunken_gather.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
|
@ -647,6 +651,9 @@ Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx, AdamWOptimize
|
|||
Status SGDOptimizerV2Base::PrepareForCompute(OpKernelContext* ctx, SGDOptimizerV2Base::Prepare& prepare) const {
|
||||
return g_host_cpu.contrib__SGDOptimizerV2Base__PrepareForCompute(this, ctx, reinterpret_cast<contrib__SGDOptimizerV2Base__Prepare&>(prepare));
|
||||
}
|
||||
void ShrunkenGatherCommon::CheckInput(const Tensor* input_tensor, const Tensor* indices_tensor, int64_t axis_in) const {
|
||||
return g_host_cpu.contrib__ShrunkenGatherCommon__CheckInput(this, input_tensor, indices_tensor, axis_in);
|
||||
}
|
||||
} // namespace contrib
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -423,5 +423,105 @@ TEST(GatherOpTest, Gather_axis1_neg_indices2d_int8) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Assertion `regionRanges != nullptr' failed
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
|
||||
TEST(ShrunkenGatherOpTest, ShrunkenGather_PositiveAxis) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.emplace_back(DefaultCpuExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.emplace_back(DefaultCudaExecutionProvider());
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
execution_providers.emplace_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
|
||||
OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute<int64_t>("axis", 0LL);
|
||||
test.AddInput<float>("data", {3, 4},
|
||||
{0.0f, 1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f});
|
||||
test.AddInput<int32_t>("indices", {2}, {1LL, 0LL});
|
||||
test.AddOutput<float>("output", {2, 4}, {4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 1.0f, 2.0f, 3.0f});
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {},
|
||||
nullptr,
|
||||
&execution_providers);
|
||||
}
|
||||
|
||||
TEST(ShrunkenGatherOpTest, ShrunkenGather_NegativeAxis) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.emplace_back(DefaultCpuExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.emplace_back(DefaultCudaExecutionProvider());
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
execution_providers.emplace_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
|
||||
OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute<int64_t>("axis", -1LL);
|
||||
test.AddInput<float>("data", {3, 4},
|
||||
{0.0f, 1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f});
|
||||
test.AddInput<int32_t>("indices", {2}, {0LL, 3LL});
|
||||
test.AddOutput<float>("output", {3, 2}, {0.0f, 3.0f, 4.0f, 7.0f, 8.0f, 11.0f});
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {},
|
||||
nullptr,
|
||||
&execution_providers);
|
||||
}
|
||||
|
||||
TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidIndicesRank) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.emplace_back(DefaultCpuExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.emplace_back(DefaultCudaExecutionProvider());
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
execution_providers.emplace_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
|
||||
OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute<int64_t>("axis", 0LL);
|
||||
test.AddInput<float>("data", {3, 4},
|
||||
{0.0f, 1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f});
|
||||
test.AddInput<int32_t>("indices", {1, 2}, {0LL, 1LL}); // invalid rank for ShrunkenGather
|
||||
test.AddOutput<float>("output", {1, 2, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "ShrunkenGather only support 1D indices, got 2-D indices", {},
|
||||
nullptr,
|
||||
&execution_providers);
|
||||
}
|
||||
|
||||
TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidInputRank) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.emplace_back(DefaultCpuExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.emplace_back(DefaultCudaExecutionProvider());
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
execution_providers.emplace_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
|
||||
OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute<int64_t>("axis", 0LL);
|
||||
test.AddInput<float>("data", {}, // invalid rank for ShrunkenGather
|
||||
{1.f});
|
||||
test.AddInput<int64_t>("indices", {1}, {0LL});
|
||||
test.AddOutput<float>("output", {}, {0.f});
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "data tensor must have rank >= 1", {},
|
||||
nullptr,
|
||||
&execution_providers);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -752,6 +752,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
|
|||
SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Shape",
|
||||
{I(0)},
|
||||
{IA("I0_shape")}),
|
||||
NodeDef(OpDef{"GatherGrad", kMSDomain, 1},
|
||||
{IA("I0_shape"), I(1), GO(0)},
|
||||
{GI(0)},
|
||||
SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetGatherElementsGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Shape",
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConvGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSqueezeGradient)
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
|
||||
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Unsqueeze", GetUnsqueezeGradient);
|
||||
|
|
|
|||
Loading…
Reference in a new issue