From 81101c9efd15a69342dec94362cf7c29a4e41be9 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 4 Jun 2020 15:00:02 -0700 Subject: [PATCH] Fix DropoutGrad op (#4052) Dropout op was recently changed to accept a new input named 'training_mode', which is passed in to DropoutGrad automatically. This PR updates the DropoutGrad schema to accommodate the new input. Tests were also update to reflect the API change Co-authored-by: Thiago Crepaldi --- .../orttraining/core/graph/gradient_schema_defs.cc | 8 +++++++- .../orttraining/test/gradient/gradient_ops_test.cc | 4 ++++ .../test/training_ops/cpu/nn/dropout_op_test.cc | 6 ++++++ orttraining/orttraining/training_ops/cuda/nn/dropout.cc | 3 ++- 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index 0682bdf6ef..e3bc26cf6b 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -1078,6 +1078,12 @@ Example 4: "the case during training.", "T1", OpSchema::Optional) + .Input(3, "training_mode", + "If set to true then it indicates dropout is being used for training. It is an optional value hence unless " + "specified explicitly, it is false. If it is false, ratio is ignored and the operation mimics inference mode where " + "nothing will be dropped from the input data and if mask is requested as output it will contain all ones.", + "T2", + OpSchema::Optional) .Output(0, "dx", "Gradient of the input.", "T") .TypeConstraint( "T", @@ -1090,7 +1096,7 @@ Example 4: .TypeConstraint( "T2", {"tensor(bool)"}, - "Constrain 'mask' types to boolean tensors.") + "Constrain 'mask' and 'training_mode' types to boolean tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateShapeAndTypeFromFirstInput(ctx); }); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 5b22f22b7f..f9c88ef415 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1492,8 +1492,12 @@ void TestDropoutGradOp(float ratio, TensorShape& x_shape, bool default_ratio = t true, false, true, false}); if (!default_ratio) { test.AddInput("ratio", {1}, ratio_data); + } else { + test.AddMissingOptionalInput(); } + test.AddInput("training_mode", {}, {true}); + test.AddOutput("dx", x_shape.GetDims(), dx_data); test.Run(); diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc index 4d6eef7b52..da3ba31a01 100644 --- a/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc @@ -176,6 +176,12 @@ void RunDropoutGradTest(const char* op, float ratio, const std::vector& test.AddInput("mask", input_shape.GetDims(), mask_buffer.get(), input_shape.Size()); if (!default_ratio) { test.AddInput("ratio", {1}, ratio_data); + } else { + test.AddMissingOptionalInput(); + } + + if (strcmp(op, "TrainableDropoutGrad") != 0) { + test.AddInput("training_mode", {}, {true}); } test.AddOutput("dx", input_shape.GetDims(), dx_data); diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc index 54cdddeb7f..fd84b6f3bb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc @@ -45,7 +45,8 @@ REGISTER_TRAINABLE_KERNEL_TYPED(double, double) .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(2), \ + .InputMemoryType(2) \ + .InputMemoryType(3), \ DropoutGrad); REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, MLFloat16)