Implement WhereGrad (#3343)

This commit is contained in:
Sherlock 2020-03-27 19:10:40 -07:00 committed by GitHub
parent 49e6043d07
commit ffb2a3359e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 60 additions and 8 deletions

View file

@ -36,7 +36,10 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"Slice", {1, 2, 3, 4}},
{"SparseSoftmaxCrossEntropy", {1, 2}},
{"ConstantOfShape", {0}},
{"Scatter", {1}}};
{"Scatter", {1}},
{"OneHot", {0, 1, 2}},
{"Where", {0}},
{"Range", {0, 1, 2}}};
class GradientGraphBuilder {
public:

View file

@ -920,6 +920,22 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) {
{GI(0)})};
}
// Gradient builder for Where(condition, X, Y).
//
// For each data input that needs a gradient, the boolean condition (or its logical
// negation) is cast to X's element type and multiplied with the incoming output
// gradient:
//   dX = dOutput * Cast(condition)
//   dY = dOutput * Cast(Not(condition))
// No nodes are emitted for the condition input itself (input 0).
//
// NOTE(review): no reduction node is emitted here, so the produced gradients take the
// broadcast shape of dOutput and the mask. This looks like it assumes X and Y are not
// broadcast against the output — TODO confirm.
IMPLEMENT_GRADIENT_BUILDER(GetWhereGradient) {
  // The Cast target is read from input 1 (X) and reused for the Y branch as well.
  const auto x_elem_type = I(1).type_proto->tensor_type().elem_type();
  const int64_t cast_target = static_cast<int64_t>(x_elem_type);

  std::vector<NodeDef> nodes;

  // Gradient w.r.t. X: keep dOutput where the condition is true.
  if (IsGradientRequiredForSrcNodeInput(1)) {
    nodes.push_back(
        NodeDef("Cast", {I(0)}, {IA("Positive_Mask")}, {MakeAttribute("to", cast_target)}));
    nodes.push_back(
        NodeDef("Mul", {GO(0), IA("Positive_Mask")}, {GI(1)}));
  }

  // Gradient w.r.t. Y: keep dOutput where the condition is false.
  if (IsGradientRequiredForSrcNodeInput(2)) {
    // The negated condition is declared with the same type as the condition input.
    nodes.push_back(
        NodeDef("Not", {I(0)}, {IA("Not_Condition", IType(0))}));
    nodes.push_back(
        NodeDef("Cast", {IA("Not_Condition")}, {IA("Negative_Mask")}, {MakeAttribute("to", cast_target)}));
    nodes.push_back(
        NodeDef("Mul", {GO(0), IA("Negative_Mask")}, {GI(2)}));
  }

  return nodes;
}
IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) {
// Send inputs: signal A, remote, data; outputs: signal B
// Recv inputs: signal B, remote; outputs: signal A', data'

View file

@ -50,6 +50,7 @@ DECLARE_GRADIENT_BUILDER(GetMegatronFGradient)
DECLARE_GRADIENT_BUILDER(GetMegatronGGradient)
DECLARE_GRADIENT_BUILDER(GetSliceGradient)
DECLARE_GRADIENT_BUILDER(GetFastGeluGradient)
DECLARE_GRADIENT_BUILDER(GetWhereGradient)
DECLARE_GRADIENT_BUILDER(GetSendGradient)
DECLARE_GRADIENT_BUILDER(GetRecvGradient)

View file

@ -11,7 +11,6 @@ namespace training {
GradientDef GetGradientForOp(const Node* node,
const std::unordered_set<std::string>& output_args_need_grad,
const std::unordered_set<std::string>& input_args_need_grad) {
// REVIEW(mzs): The below condition does not seem correct, it needs to be >= GRADIENT_OP_VERSION
// but changing it will break a bunch of tests since many operators like sqrt are version 6,
// yet have a grad operator. However changing the opset requires changing the operator
@ -23,7 +22,7 @@ GradientDef GetGradientForOp(const Node* node,
// REVIEW(bahuang): We don't have a version control for forward to backward op mapping.
// Current SliceGrad(kMSDomain, 1) only supports Slice(kOnnxDomain, 10/11) because adding grad operator for versions
// less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11.
/*ORT_ENFORCE(
node->Op()->SinceVersion() <= GRADIENT_OP_VERSION,
"Gradients are supported for opset version" + std::to_string(node->Op()->SinceVersion()) +
@ -89,6 +88,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("MegatronG", GetMegatronGGradient);
REGISTER_GRADIENT_BUILDER("Slice", GetSliceGradient);
REGISTER_GRADIENT_BUILDER("FastGelu", GetFastGeluGradient);
REGISTER_GRADIENT_BUILDER("Where", GetWhereGradient);
REGISTER_GRADIENT_BUILDER("Send", GetSendGradient);
REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient);
};

View file

@ -68,7 +68,6 @@ inline std::vector<OrtValue> GradientChecker<X_T, Y_T, JAC_T>::EvaluateFunctionA
const std::vector<TensorInfo>& y_infos,
std::vector<std::vector<X_T>>* x_datas,
std::vector<std::vector<Y_T>>* y_datas) {
// clear OpTester input/output/initializer_index
op_session.ClearData();
@ -84,6 +83,12 @@ inline std::vector<OrtValue> GradientChecker<X_T, Y_T, JAC_T>::EvaluateFunctionA
std::vector<int32_t> int32_data(data.size());
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
for (size_t i = 0; i < data.size(); ++i) {
p_data[i] = static_cast<bool>(data[i]);
}
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
} else {
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
}
@ -140,6 +145,12 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
std::vector<int32_t> int32_data(data.size());
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
for (size_t i = 0; i < data.size(); ++i) {
p_data[i] = static_cast<bool>(data[i]);
}
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
} else {
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
}
@ -193,7 +204,6 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
std::vector<std::vector<Y_T>>* y_datas,
const std::vector<AttributeProto>& attributes,
const std::unordered_map<std::string, int>& extra_domain_to_version) {
for (size_t data_index = 0; data_index < x_datas->size(); data_index++) {
std::string name = "input" + std::to_string(data_index);
const std::vector<X_T>& data = (*x_datas)[data_index];
@ -206,6 +216,12 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
std::vector<int32_t> int32_data(data.size());
std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast<int32_t>(x); });
op_session.AddInput<int32_t>(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data);
} else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType<bool>()) {
std::unique_ptr<bool[]> p_data(new bool[data.size()]);
for (size_t i = 0; i < data.size(); ++i) {
p_data[i] = static_cast<bool>(data[i]);
}
op_session.AddInput<bool>(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size());
} else {
op_session.AddInput<X_T>(name.c_str(), x_infos[data_index].shape.GetDims(), data);
}
@ -228,8 +244,8 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGraph(
status = graph.Resolve();
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "Resolve failed with status: " << status.ErrorMessage();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
LOGS_DEFAULT(ERROR) << "Resolve failed with status: " << status.ErrorMessage();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
}
if (!status.IsOK()) {
@ -251,7 +267,6 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGradGraph(
std::vector<std::vector<X_T>>* x_datas,
std::vector<std::vector<Y_T>>* y_datas,
const std::vector<AttributeProto>& attributes) {
std::unordered_map<std::string, int> extra_domain_to_version{{kMSDomain, 1}, {kOnnxDomain, 9}};
InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes, extra_domain_to_version);

View file

@ -1446,6 +1446,23 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) {
}
#endif
// Runs the gradient checker on the "Where" op with a boolean condition input and two
// float data inputs, all of shape {4, 3, 2}, and expects the reported max error to be tiny.
TEST(GradientCheckerTest, WhereGrad) {
  std::vector<int64_t> dims{4, 3, 2};

  // Maps a float sample to exactly 0.0f or 1.0f: 1.0f when the fractional part of |x|
  // is greater than 0.5, otherwise 0.0f.
  std::function<float(float)> to_zero_or_one = [](float x) {
    const float fractional_part = std::fmod(std::fabs(x), 1.0f);
    return fractional_part > 0.5f ? 1.0f : 0.0f;
  };

  // Condition input is declared as a bool tensor and uses the transformer above.
  // NOTE(review): the `false` argument presumably marks this input as not needing a
  // gradient — confirm against the TensorInfo constructor.
  TensorInfo condition_info(dims, false, &to_zero_or_one, DataTypeImpl::GetTensorType<bool>());
  TensorInfo x_info(dims);
  TensorInfo y_info(dims);
  TensorShape output_shape{dims};

  OpDef op_def{"Where"};
  GradientChecker<float, float, float> gradient_checker;
  float max_error;
  gradient_checker.ComputeGradientError(op_def, {condition_info, x_info, y_info}, {output_shape}, &max_error);

  EXPECT_IS_TINY(max_error);
}
TEST(GradientCheckerTest, SliceGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;