Fix DropoutGrad op (#4052)

Dropout op was recently changed to accept a new input named
'training_mode', which is passed in to DropoutGrad automatically.

This PR updates the DropoutGrad schema to accommodate the new input.
Tests were also update to reflect the API change

Co-authored-by: Thiago Crepaldi <thiag.crepaldi@microsoft.com>
This commit is contained in:
Thiago Crepaldi 2020-06-04 15:00:02 -07:00 committed by GitHub
parent 6199ef1375
commit 81101c9efd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 2 deletions

View file

@ -1078,6 +1078,12 @@ Example 4:
"the case during training.",
"T1",
OpSchema::Optional)
.Input(3, "training_mode",
"If set to true then it indicates dropout is being used for training. It is an optional value hence unless "
"specified explicitly, it is false. If it is false, ratio is ignored and the operation mimics inference mode where "
"nothing will be dropped from the input data and if mask is requested as output it will contain all ones.",
"T2",
OpSchema::Optional)
.Output(0, "dx", "Gradient of the input.", "T")
.TypeConstraint(
"T",
@ -1090,7 +1096,7 @@ Example 4:
.TypeConstraint(
"T2",
{"tensor(bool)"},
"Constrain 'mask' types to boolean tensors.")
"Constrain 'mask' and 'training_mode' types to boolean tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateShapeAndTypeFromFirstInput(ctx);
});

View file

@ -1492,8 +1492,12 @@ void TestDropoutGradOp(float ratio, TensorShape& x_shape, bool default_ratio = t
true, false, true, false});
if (!default_ratio) {
test.AddInput<float>("ratio", {1}, ratio_data);
} else {
test.AddMissingOptionalInput<float>();
}
test.AddInput("training_mode", {}, {true});
test.AddOutput<float>("dx", x_shape.GetDims(), dx_data);
test.Run();

View file

@ -176,6 +176,12 @@ void RunDropoutGradTest(const char* op, float ratio, const std::vector<int64_t>&
test.AddInput<bool>("mask", input_shape.GetDims(), mask_buffer.get(), input_shape.Size());
if (!default_ratio) {
test.AddInput<float>("ratio", {1}, ratio_data);
} else {
test.AddMissingOptionalInput<float>();
}
if (strcmp(op, "TrainableDropoutGrad") != 0) {
test.AddInput<bool>("training_mode", {}, {true});
}
test.AddOutput<float>("dx", input_shape.GetDims(), dx_data);

View file

@ -45,7 +45,8 @@ REGISTER_TRAINABLE_KERNEL_TYPED(double, double)
.TypeConstraint("T", DataTypeImpl::GetTensorType<T1>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T2>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()) \
.InputMemoryType<OrtMemTypeCPUInput>(2), \
.InputMemoryType<OrtMemTypeCPUInput>(2) \
.InputMemoryType<OrtMemTypeCPUInput>(3), \
DropoutGrad<T1, T2>);
REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, MLFloat16)