mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Implement WhereGrad (#3343)
This commit is contained in:
parent
49e6043d07
commit
ffb2a3359e
6 changed files with 60 additions and 8 deletions
|
|
@ -36,7 +36,10 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
|
|||
{"Slice", {1, 2, 3, 4}},
|
||||
{"SparseSoftmaxCrossEntropy", {1, 2}},
|
||||
{"ConstantOfShape", {0}},
|
||||
{"Scatter", {1}}};
|
||||
{"Scatter", {1}},
|
||||
{"OneHot", {0, 1, 2}},
|
||||
{"Where", {0}},
|
||||
{"Range", {0, 1, 2}}};
|
||||
|
||||
class GradientGraphBuilder {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -920,6 +920,22 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) {
|
|||
{GI(0)})};
|
||||
}
|
||||
|
||||
// Builds the backward sub-graph for Where(condition, X, Y).
//
// For each of X (input 1) and Y (input 2) that requires a gradient, the
// incoming output gradient GO(0) is multiplied by a numeric mask derived from
// the boolean condition (input 0):
//   GI(1) = GO(0) * Cast(condition)
//   GI(2) = GO(0) * Cast(Not(condition))
// Both masks are cast to the element type recorded on input 1. No gradient
// nodes are produced for the condition itself.
//
// NOTE(review): only Cast/Not/Mul nodes are emitted here (no reduction), so
// this looks like it assumes X and Y already have the shape of the output --
// confirm the behaviour when the inputs are broadcast.
IMPLEMENT_GRADIENT_BUILDER(GetWhereGradient) {
  // Cast target for the masks: the tensor element type of input 1.
  const int64_t mask_elem_type =
      static_cast<int64_t>(I(1).type_proto->tensor_type().elem_type());

  std::vector<NodeDef> nodes;

  // Gradient w.r.t. X: keep GO(0) where the condition is true.
  if (IsGradientRequiredForSrcNodeInput(1)) {
    nodes.push_back(NodeDef("Cast",
                            {I(0)},
                            {IA("Positive_Mask")},
                            {MakeAttribute("to", mask_elem_type)}));
    nodes.push_back(NodeDef("Mul",
                            {GO(0), IA("Positive_Mask")},
                            {GI(1)}));
  }

  // Gradient w.r.t. Y: keep GO(0) where the condition is false, so the
  // condition is inverted before it is turned into a numeric mask.
  if (IsGradientRequiredForSrcNodeInput(2)) {
    nodes.push_back(NodeDef("Not",
                            {I(0)},
                            {IA("Not_Condition", IType(0))}));
    nodes.push_back(NodeDef("Cast",
                            {IA("Not_Condition")},
                            {IA("Negative_Mask")},
                            {MakeAttribute("to", mask_elem_type)}));
    nodes.push_back(NodeDef("Mul",
                            {GO(0), IA("Negative_Mask")},
                            {GI(2)}));
  }

  return nodes;
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) {
|
||||
// Send inputs: signal A, remote, data; outputs: signal B
|
||||
// Recv inputs: signal B, remote; outputs: signal A', data'
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ DECLARE_GRADIENT_BUILDER(GetMegatronFGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetMegatronGGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSliceGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetFastGeluGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetWhereGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSendGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetRecvGradient)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ namespace training {
|
|||
GradientDef GetGradientForOp(const Node* node,
|
||||
const std::unordered_set<std::string>& output_args_need_grad,
|
||||
const std::unordered_set<std::string>& input_args_need_grad) {
|
||||
|
||||
// REVIEW(mzs): The below condition does not seem correct, it needs to be >= GRADIENT_OP_VERSION
|
||||
// but changing it will break a bunch of tests since many operators like sqrt are version 6,
|
||||
// yet have a grad operator. However changing the opset requires changing the operator
|
||||
|
|
@ -23,7 +22,7 @@ GradientDef GetGradientForOp(const Node* node,
|
|||
// REVIEW(bahuang): We don't have a version control for forward to backward op mapping.
|
||||
// Current SliceGrad(kMSDomain, 1) only supports Slice(kOnnxDomain, 10/11) because adding grad operator for versions
|
||||
// less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11.
|
||||
|
||||
|
||||
/*ORT_ENFORCE(
|
||||
node->Op()->SinceVersion() <= GRADIENT_OP_VERSION,
|
||||
"Gradients are supported for opset version" + std::to_string(node->Op()->SinceVersion()) +
|
||||
|
|
@ -89,6 +88,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("MegatronG", GetMegatronGGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Slice", GetSliceGradient);
|
||||
REGISTER_GRADIENT_BUILDER("FastGelu", GetFastGeluGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Where", GetWhereGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ inline std::vector<OrtValue> GradientChecker<X_T, Y_T, JAC_T>::EvaluateFunctionA
|
|||
const std::vector<TensorInfo>& y_infos,
|
||||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas) {
|
||||
|
||||
// clear OpTester input/output/initializer_index
|
||||
op_session.ClearData();
|
||||
|
||||
|
|
@ -84,6 +83,12 @@ inline std::vector<OrtValue> GradientChecker<X_T, Y_T, JAC_T>::EvaluateFunctionA
|
|||
std::vector<int32_t> int32_data(data.size());
|
||||
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
|
||||
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
|
||||
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
|
||||
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
p_data[i] = static_cast<bool>(data[i]);
|
||||
}
|
||||
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
|
||||
} else {
|
||||
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
|
||||
}
|
||||
|
|
@ -140,6 +145,12 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
|
|||
std::vector<int32_t> int32_data(data.size());
|
||||
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
|
||||
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
|
||||
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
|
||||
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
p_data[i] = static_cast<bool>(data[i]);
|
||||
}
|
||||
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
|
||||
} else {
|
||||
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
|
||||
}
|
||||
|
|
@ -193,7 +204,6 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
|
|||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
const std::vector<AttributeProto>& attributes,
|
||||
const std::unordered_map<std::string, int>& extra_domain_to_version) {
|
||||
|
||||
for (size_t data_index = 0; data_index < x_datas->size(); data_index++) {
|
||||
std::string name = "input" + std::to_string(data_index);
|
||||
const std::vector<X_T>& data = (*x_datas)[data_index];
|
||||
|
|
@ -206,6 +216,12 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
|
|||
std::vector<int32_t> int32_data(data.size());
|
||||
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
|
||||
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
|
||||
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
|
||||
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
p_data[i] = static_cast<bool>(data[i]);
|
||||
}
|
||||
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
|
||||
} else {
|
||||
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
|
||||
}
|
||||
|
|
@ -228,8 +244,8 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
|
|||
status = graph.Resolve();
|
||||
|
||||
if (!status.IsOK()) {
|
||||
LOGS_DEFAULT(ERROR) << "Resolve failed with status: " << status.ErrorMessage();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
LOGS_DEFAULT(ERROR) << "Resolve failed with status: " << status.ErrorMessage();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
}
|
||||
|
||||
if (!status.IsOK()) {
|
||||
|
|
@ -251,7 +267,6 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGradGraph(
|
|||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
const std::vector<AttributeProto>& attributes) {
|
||||
|
||||
std::unordered_map<std::string, int> extra_domain_to_version{{kMSDomain, 1}, {kOnnxDomain, 9}};
|
||||
InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes, extra_domain_to_version);
|
||||
|
||||
|
|
|
|||
|
|
@ -1446,6 +1446,23 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) {
|
|||
}
|
||||
#endif
|
||||
|
||||
// Runs GradientChecker::ComputeGradientError on the "Where" op with a
// {4, 3, 2} condition / X / Y and expects the reported max error to be tiny.
TEST(GradientCheckerTest, WhereGrad) {
  GradientChecker<float, float, float> checker;
  OpDef where_op{"Where"};
  float max_error;

  std::vector<int64_t> dims{4, 3, 2};

  // Maps a float sample onto 0.0f / 1.0f: 1.0f when the fractional part of
  // |value| is greater than 0.5, 0.0f otherwise.
  std::function<float(float)> to_bool_mask = [](float value) {
    const float fraction = std::fmod(std::fabs(value), 1.0f);
    return fraction > 0.5f ? 1.0f : 0.0f;
  };

  // Condition input is declared as a bool tensor and uses the transformer
  // above. NOTE(review): the `false` argument presumably marks this input as
  // not needing a gradient -- confirm against TensorInfo's constructor.
  TensorInfo condition_info(dims, false, &to_bool_mask, DataTypeImpl::GetTensorType<bool>());
  TensorInfo x_info(dims);
  TensorInfo y_info(dims);

  TensorShape output_shape{dims};
  checker.ComputeGradientError(where_op, {condition_info, x_info, y_info}, {output_shape}, &max_error);
  EXPECT_IS_TINY(max_error);
}
|
||||
|
||||
TEST(GradientCheckerTest, SliceGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
|
|
|||
Loading…
Reference in a new issue