mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
ScatterNDGrad (#8261)
This commit is contained in:
parent
97f1eea2ea
commit
ef8f50c4ab
5 changed files with 121 additions and 2 deletions
|
|
@ -1836,5 +1836,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetPadGradient) {
|
|||
NodeDef("Pad", {GO(0), IA("Neg_pads")}, {GI(0)})};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetScatterNDGradient) {
|
||||
std::vector<NodeDef> result;
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
result.emplace_back(NodeDef("Shape", {I(2)}, {IA("Shape_updates")}));
|
||||
result.emplace_back(NodeDef("ConstantOfShape", {IA("Shape_updates")}, {IA("Zero_Shape_updates")},
|
||||
{MakeAttribute("value", ScalarTensorProtoByElemType(0.0f, IElemType(0)))}));
|
||||
result.emplace_back(NodeDef("ScatterND", {GO(0), I(1), IA("Zero_Shape_updates")}, {GI(0)}));
|
||||
}
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(2)) {
|
||||
result.emplace_back(NodeDef("GatherND", {GO(0), I(1)}, {GI(2)}));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ DECLARE_GRADIENT_BUILDER(GetATenOpGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetPadGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetIdentityGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPythonOpGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetScatterNDGradient)
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -210,14 +210,19 @@ class GradientBuilderBase {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
|
||||
static ONNX_NAMESPACE::TensorProto ScalarTensorProto(T value, std::vector<int64_t> shape) {
|
||||
ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1));
|
||||
|
||||
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(value);
|
||||
for (auto dim : shape) {
|
||||
t_proto.add_dims(dim);
|
||||
}
|
||||
|
||||
return t_proto;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
|
||||
auto t_proto = ScalarTensorProto(value, shape);
|
||||
return NodeDef("Constant",
|
||||
{},
|
||||
{ArgDef(arg_name, nullptr)},
|
||||
|
|
@ -237,6 +242,18 @@ class GradientBuilderBase {
|
|||
return ConstantScalarNode(value, {1}, arg_name);
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::TensorProto ScalarTensorProtoByElemType(float value, int elem_type) {
|
||||
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
return ScalarTensorProto(MLFloat16(math::floatToHalf(value)), {1});
|
||||
}
|
||||
|
||||
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
|
||||
return ScalarTensorProto(BFloat16(value), {1});
|
||||
}
|
||||
|
||||
return ScalarTensorProto(value, {1});
|
||||
}
|
||||
|
||||
static NodeDef ZeroConstantNode(int elem_type) {
|
||||
return ConstantScalarNode(0.0f, "ZeroConstant_Type" + std::to_string(elem_type), elem_type);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Pad", GetPadGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Identity", GetIdentityGradient);
|
||||
REGISTER_GRADIENT_BUILDER("PythonOp", GetPythonOpGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ScatterND", GetScatterNDGradient);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -2653,6 +2653,90 @@ TEST(GradientCheckerTest, PadGrad) {
|
|||
}
|
||||
#endif // USE_CUDA
|
||||
|
||||
TEST(GradientCheckerTest, ScatterNDGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"ScatterND", kOnnxDomain, 11};
|
||||
|
||||
{
|
||||
TensorInfo data_info({8}, true);
|
||||
TensorInfo indices_info({4, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo updates_info({4}, true);
|
||||
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {4, 3, 1, 7}, {8, 9, 10, 11}};
|
||||
|
||||
TensorInfo output_info({8}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
|
||||
{output_info}, &max_error, input_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo data_info({2, 2}, true);
|
||||
TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo updates_info({2}, true);
|
||||
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3}, {0, 0, 1, 1}, {4, 5}};
|
||||
|
||||
TensorInfo output_info({2, 2}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
|
||||
{output_info}, &max_error, input_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo data_info({2, 2}, true);
|
||||
TensorInfo indices_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo updates_info({2, 2}, true);
|
||||
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3}, {1, 0}, {4, 5, 6, 7}};
|
||||
|
||||
TensorInfo output_info({2, 2}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
|
||||
{output_info}, &max_error, input_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo data_info({2, 2, 2}, true);
|
||||
TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo updates_info({2, 2}, true);
|
||||
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 1, 0}, {8, 9, 10, 11}};
|
||||
|
||||
TensorInfo output_info({2, 2, 2}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
|
||||
{output_info}, &max_error, input_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo data_info({2, 2, 2}, true);
|
||||
TensorInfo indices_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo updates_info({2, 1, 2}, true);
|
||||
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 1, 0}, {8, 9, 10, 11}};
|
||||
|
||||
TensorInfo output_info({2, 2, 2}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
|
||||
{output_info}, &max_error, input_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo data_info({2, 2, 2}, true);
|
||||
TensorInfo indices_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
|
||||
TensorInfo updates_info({2, 2, 2}, true);
|
||||
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1}, {8, 9, 10, 11, 12, 13, 14, 15}};
|
||||
|
||||
TensorInfo output_info({2, 2, 2}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
|
||||
{output_info}, &max_error, input_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue