From a24c71af4091fb3f61088e455fdf54bd3c780c2a Mon Sep 17 00:00:00 2001 From: "M. Zeeshan Siddiqui" Date: Mon, 4 May 2020 20:05:42 -0700 Subject: [PATCH] Update Dropout(12) forward kernel with training_mode input. (#3805) * Update Dropout(12) forward and backward kernel with training_mode input. * Revert deleted assert. * clean up. * PR feedback. --- .../training_ops/cpu/nn/dropout_op_test.cc | 17 +- .../training_ops/cpu/cpu_training_kernels.cc | 1 + .../training_ops/cpu/nn/dropout_op.cc | 88 +++++++---- .../training_ops/cpu/nn/dropout_op.h | 7 +- .../cpu/nn/trainable_dropout_op.cc | 58 ------- .../training_ops/cuda/nn/dropout.cc | 147 +++++++++++++----- .../training_ops/cuda/nn/dropout.h | 6 +- .../training_ops/cuda/nn/dropout_impl.cu | 24 +-- .../training_ops/cuda/nn/trainable_dropout.cc | 46 ------ 9 files changed, 192 insertions(+), 202 deletions(-) delete mode 100644 orttraining/orttraining/training_ops/cpu/nn/trainable_dropout_op.cc delete mode 100644 orttraining/orttraining/training_ops/cuda/nn/trainable_dropout.cc diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc index 8183078bd8..1348f370e0 100644 --- a/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc @@ -34,7 +34,7 @@ const Tensor& FetchTensor(const OrtValue& ort_value) { } void RunDropoutTest(const char* op, const bool use_mask, const std::vector& input_shape, float ratio = -1, - bool use_float16_ratio = false) { + bool training_mode = true, bool use_float16_ratio = false) { OpTester t{op, k_dropout_opset_version, kOnnxDomain}; const auto input_size = std::accumulate( @@ -47,12 +47,17 @@ void RunDropoutTest(const char* op, const bool use_mask, const std::vector("output", input_shape, input); // we'll do our own output verification std::unique_ptr mask_buffer{}; @@ -117,7 +122,7 @@ TEST(DropoutTest, Mask) { } TEST(DropoutTest, RatioLimit) { - RunDropoutTest("Dropout", true, {1000}, 0.0f); + RunDropoutTest("Dropout", true, {1000}, 0.0f, false); } TEST(DropoutTest, EmptyRatio) { @@ -125,7 +130,7 @@ TEST(DropoutTest, EmptyRatio) { } TEST(DropoutTest, Float16Ratio) { - RunDropoutTest("Dropout", true, {1000}, 0.0f, true); + RunDropoutTest("Dropout", true, {1000}, 0.0f, true, true); } TEST(TrainableDropoutTest, Basic) { @@ -137,15 +142,15 @@ TEST(TrainableDropoutTest, Mask) { } TEST(TrainableDropoutTest, RatioLimit) { - RunDropoutTest("TrainableDropout", true, {1000}, 0.0f); + RunDropoutTest("TrainableDropout", true, {1000}, 0.0f, false); } TEST(TrainableDropoutTest, EmptyRatio) { - RunDropoutTest("TrainableDropout", true, {1000}); + RunDropoutTest("TrainableDropout", true, {1000}, -1); } TEST(TrainableDropoutTest, Float16Ratio) { - RunDropoutTest("TrainableDropout", true, {1000}, 0.0f, true); + RunDropoutTest("TrainableDropout", true, {1000}, 0.0f, true, true); } namespace { diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index c8783baff9..69fef4b6c8 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -44,6 +44,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double_MLFloat16, TrainableDropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double_float, TrainableDropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double_double, TrainableDropout); + // REVIEW(mzs): ConstEigenVectorArrayMap.cast()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - OpName); + Dropout); + +// Temporary for backward compatibility, will eventually get rid of TrainableDropout when PyTorch exporter will move to +// opset-12. +REGISTER_KERNEL_TYPED(TrainableDropout, 9, float, MLFloat16, true) +REGISTER_KERNEL_TYPED(TrainableDropout, 9, float, float, true) +REGISTER_KERNEL_TYPED(TrainableDropout, 9, float, double, true) +REGISTER_KERNEL_TYPED(TrainableDropout, 9, double, MLFloat16, true) +REGISTER_KERNEL_TYPED(TrainableDropout, 9, double, float, true) +REGISTER_KERNEL_TYPED(TrainableDropout, 9, double, double, true) // REVIEW(mzs): ConstEigenVectorArrayMap.cast -Status Dropout::Compute(OpKernelContext* context) const { +template +Status Dropout::Compute(OpKernelContext* context) const { const Tensor* X = context->Input(0); auto X_span = X->DataAsSpan(); - const Tensor* ratio = context->Input(1); // optional const float ratio_value = GetRatioOrDefault(ratio); - const auto& X_shape = X->Shape(); - Tensor* Y = context->Output(0, X_shape); auto Y_span = Y->MutableDataAsSpan(); - Tensor* mask = context->Output(1, X_shape); // optional std::unique_ptr temp_mask_buffer{}; // temporary buffer to use if mask input is not provided auto mask_span = [&X_shape, mask, &temp_mask_buffer]() { @@ -75,16 +80,21 @@ Status Dropout::Compute(OpKernelContext* context) const { return gsl::make_span(temp_mask_buffer.get(), X_shape.Size()); }(); - ORT_ENFORCE(Y->Shape() == X_shape, "X and Y should have the same shape"); ORT_ENFORCE(!mask || mask->Shape() == X_shape, "X and mask should have the same shape"); - if (ratio_value == 0.0f) { + const Tensor* training_mode = context->Input(2); + if ((0 == ratio_value /*Backward compat with TrainableDropout*/) || + !trainable_dropout && (training_mode == nullptr || *(training_mode->Data()) == false)) { // drop none if (X_span.data() != Y_span.data()) { std::copy(X_span.begin(), X_span.end(), Y_span.begin()); } - std::fill(mask_span.begin(), mask_span.end(), true); - } else if (ratio_value < 1.0f) { + + if (mask != nullptr) { + std::fill(mask_span.begin(), mask_span.end(), true); + } + + } else { // drop some ConstEigenVectorArrayMap X_arr(X_span.data(), X_span.size()); EigenVectorArrayMap Y_arr(Y_span.data(), Y_span.size()); @@ -106,33 +116,51 @@ Status Dropout::Compute(OpKernelContext* context) const { return Status::OK(); } +#define REGISTER_GRADIENT_KERNEL_TYPED(OpName, T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + OpName, \ + kMSDomain, \ + 1, \ + T1##_##T2, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + DropoutGrad); + // DropoutGrad -// REVIEW(mzs): ConstEigenVectorArrayMap.cast Status DropoutGrad::Compute(OpKernelContext* context) const { const Tensor* dY = context->Input(0); auto dY_span = dY->DataAsSpan(); - const Tensor* mask = context->Input(1); auto mask_span = mask->DataAsSpan(); - const Tensor* ratio = context->Input(2); // optional const float ratio_value = GetRatioOrDefault(ratio); - const auto& dY_shape = dY->Shape(); - Tensor* dX = context->Output(0, dY_shape); auto dX_span = dX->MutableDataAsSpan(); diff --git a/orttraining/orttraining/training_ops/cpu/nn/dropout_op.h b/orttraining/orttraining/training_ops/cpu/nn/dropout_op.h index da102dc577..54d0d6f351 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/dropout_op.h +++ b/orttraining/orttraining/training_ops/cpu/nn/dropout_op.h @@ -9,8 +9,8 @@ namespace onnxruntime { namespace contrib { -template -class Dropout final : public OpKernel { +template +class Dropout final: public OpKernel { public: Dropout(const OpKernelInfo& info) : OpKernel{info} { int64_t seed = 0; @@ -28,7 +28,8 @@ class Dropout final : public OpKernel { template class DropoutGrad final : public OpKernel { public: - DropoutGrad(const OpKernelInfo& info) : OpKernel{info} {} + DropoutGrad(const OpKernelInfo& info) : OpKernel{info} { + } Status Compute(OpKernelContext* context) const override; }; diff --git a/orttraining/orttraining/training_ops/cpu/nn/trainable_dropout_op.cc b/orttraining/orttraining/training_ops/cpu/nn/trainable_dropout_op.cc deleted file mode 100644 index d66771954f..0000000000 --- a/orttraining/orttraining/training_ops/cpu/nn/trainable_dropout_op.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/cpu/nn/dropout_op.h" -#include -#include "core/util/math_cpuonly.h" - -namespace onnxruntime { -namespace contrib { - -// TrainableDropout is the same as Dropout V12. -// Registering the operator for the sake of backward compatibility. -// Give notice to the users to use Dropout V12 and then deprecate this kernel. - -// TrainableDropout -#define REGISTER_KERNEL_TYPED(OpName, Domain, VER, T1, T2, ClassName) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - OpName, \ - Domain, \ - VER, \ - T1##_##T2, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - ClassName); - -// REVIEW(mzs): ConstEigenVectorArrayMap.cast()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(MemIndex), \ - OpName); +#define REGISTER_KERNEL_TYPED(T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Dropout, \ + kOnnxDomain, \ + 12, \ + T1##_##T2, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(1) \ + .InputMemoryType(2), \ + Dropout); -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, MLFloat16, MLFloat16, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, MLFloat16, float, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, MLFloat16, double, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, float, MLFloat16, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, float, float, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, float, double, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, MLFloat16, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, float, 1) -REGISTER_KERNEL_TYPED(Dropout, kOnnxDomain, 12, double, double, 1) +REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16) +REGISTER_KERNEL_TYPED(MLFloat16, float) +REGISTER_KERNEL_TYPED(MLFloat16, double) +REGISTER_KERNEL_TYPED(float, MLFloat16) +REGISTER_KERNEL_TYPED(float, float) +REGISTER_KERNEL_TYPED(float, double) +REGISTER_KERNEL_TYPED(double, MLFloat16) +REGISTER_KERNEL_TYPED(double, float) +REGISTER_KERNEL_TYPED(double, double) +#define REGISTER_TRAINABLE_KERNEL_TYPED(T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + TrainableDropout, \ + kOnnxDomain, \ + 9, \ + T1##_##T2, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(1), \ + Dropout); -template -Status Dropout::ComputeInternal(OpKernelContext* context) const { +// Temporary for backward compatibility, will eventually get rid of TrainableDropout when PyTorch exporter will move to +// opset-12. +REGISTER_TRAINABLE_KERNEL_TYPED(MLFloat16, MLFloat16) +REGISTER_TRAINABLE_KERNEL_TYPED(MLFloat16, float) +REGISTER_TRAINABLE_KERNEL_TYPED(MLFloat16, double) +REGISTER_TRAINABLE_KERNEL_TYPED(float, MLFloat16) +REGISTER_TRAINABLE_KERNEL_TYPED(float, float) +REGISTER_TRAINABLE_KERNEL_TYPED(float, double) +REGISTER_TRAINABLE_KERNEL_TYPED(double, MLFloat16) +REGISTER_TRAINABLE_KERNEL_TYPED(double, float) +REGISTER_TRAINABLE_KERNEL_TYPED(double, double) + +template +Status Dropout::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; //Get X_data @@ -52,12 +77,6 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { //Get mask_data auto mask = context->Output(1, shape); ORT_ENFORCE(!mask || mask->Shape().Size() == N); - IAllocatorUniquePtr temp_mask_buffer{}; // buffer to use if mask is not provided - bool* const mask_data = [this, N, mask, &temp_mask_buffer]() { - if (mask) return mask->MutableData(); - temp_mask_buffer = GetScratchBuffer(N); - return temp_mask_buffer.get(); - }(); //Get the ratio_data float ratio_data; @@ -73,21 +92,70 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { } ORT_ENFORCE(ratio_data >= 0.0f && ratio_data < 1.0f); + const Tensor* training_mode = context->Input(2); + //Check for inference mode. + if ((0 == ratio_data /*Backward compat with TrainableDropout*/) || + (!trainable_dropout && (training_mode == nullptr || *(training_mode->Data()) == false))) { + if (Y_data != X_data) { + CUDA_CALL_THROW(cudaMemcpyAsync(Y_data, X_data, N * sizeof(T1), cudaMemcpyDeviceToDevice)); + } + + // If mask is requested, return all 1s. + if (mask != nullptr) { + ORT_ENFORCE(cudaMemset(mask->MutableData(), true, N * sizeof(bool)) == cudaSuccess); + } + + return Status::OK(); + } + + IAllocatorUniquePtr temp_mask_buffer{}; // buffer to use if mask is not provided + bool* const mask_data = [this, N, mask, &temp_mask_buffer]() { + if (mask) return mask->MutableData(); + temp_mask_buffer = GetScratchBuffer(N); + return temp_mask_buffer.get(); + }(); + PhiloxGenerator& generator = generator_ != nullptr ? *generator_.get() : PhiloxGenerator::Default(); DropoutKernelImpl(GetDeviceProp(), N, ratio_data, generator, X_data, Y_data, mask_data); return Status::OK(); } -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, MLFloat16, MLFloat16, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, MLFloat16, float, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, MLFloat16, double, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, float, MLFloat16, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, float, float, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, float, double, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, double, MLFloat16, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, double, float, 2) -REGISTER_KERNEL_TYPED(DropoutGrad, kMSDomain, 1, double, double, 2) +#define REGISTER_GRADIENT_KERNEL_TYPED(OpName, T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + OpName, \ + kMSDomain, \ + 1, \ + T1##_##T2, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(2), \ + DropoutGrad); + +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, MLFloat16) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, float) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, double) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, float, MLFloat16) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, float, float) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, float, double) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, double, MLFloat16) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, double, float) +REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, double, double) + +// Temporary for backward compatibility, will eventually get rid of TrainableDropout when PyTorch exporter will move to +// opset-12. +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, MLFloat16, MLFloat16) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, MLFloat16, float) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, MLFloat16, double) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, float, MLFloat16) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, float, float) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, float, double) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, double, MLFloat16) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, double, float) +REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, double, double) template Status DropoutGrad::ComputeInternal(OpKernelContext* context) const { @@ -103,7 +171,6 @@ Status DropoutGrad::ComputeInternal(OpKernelContext* context) const { auto dX = context->Output(0, shape); auto dX_data = reinterpret_cast(dX->template MutableData()); - float ratio_data; auto ratio = context->Input(2); diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout.h b/orttraining/orttraining/training_ops/cuda/nn/dropout.h index f808419e82..dca18128d0 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout.h +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace cuda { -template +template class Dropout final : public CudaKernel { public: Dropout(const OpKernelInfo& info) : CudaKernel(info), default_ratio_(0.5) { @@ -29,7 +29,9 @@ class Dropout final : public CudaKernel { template class DropoutGrad final : public CudaKernel { public: - DropoutGrad(const OpKernelInfo& info) : CudaKernel(info), default_ratio_(0.5) {} + DropoutGrad(const OpKernelInfo& info) : CudaKernel(info), default_ratio_(0.5) { + } + Status ComputeInternal(OpKernelContext* context) const override; private: diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu index cc295f9a5f..3bbb08ba08 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu @@ -21,9 +21,6 @@ #include #include -#include "thrust/device_ptr.h" -#include "thrust/fill.h" - namespace onnxruntime { namespace cuda { @@ -77,22 +74,15 @@ void DropoutKernelImpl( const T* X_data, T* Y_data, bool* mask_data) { - if (ratio == 0.0f) { - if (Y_data != X_data) { - CUDA_CALL_THROW(cudaMemcpyAsync(Y_data, X_data, N * sizeof(T), cudaMemcpyDeviceToDevice)); - } - thrust::fill_n(thrust::device_pointer_cast(mask_data), N, true); - } else { - const int block_size = 256; - const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / block_size; - const int grid_size = std::min(prop.multiProcessorCount * blocks_per_sm, static_cast(CeilDiv(N, block_size))); + const int block_size = 256; + const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / block_size; + const int grid_size = std::min(prop.multiProcessorCount * blocks_per_sm, static_cast(CeilDiv(N, block_size))); - // Compute the number of random numbers generated by each thread, and increment philox generator offset by that amount. - const uint64_t counter_offset = static_cast(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL); - auto seeds = generator.NextPhiloxSeeds(counter_offset); + // Compute the number of random numbers generated by each thread, and increment philox generator offset by that amount. + const uint64_t counter_offset = static_cast(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL); + auto seeds = generator.NextPhiloxSeeds(counter_offset); - DropoutKernel<<>>(N, ratio, seeds, X_data, Y_data, mask_data); - } + DropoutKernel<<>>(N, ratio, seeds, X_data, Y_data, mask_data); } #define SPECIALIZED_DROPOUT_IMPL(T) \ diff --git a/orttraining/orttraining/training_ops/cuda/nn/trainable_dropout.cc b/orttraining/orttraining/training_ops/cuda/nn/trainable_dropout.cc deleted file mode 100644 index 54ddd94a11..0000000000 --- a/orttraining/orttraining/training_ops/cuda/nn/trainable_dropout.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "orttraining/training_ops/cuda/nn/dropout.h" - -#include "core/providers/common.h" - -namespace onnxruntime { -namespace cuda { - -#define REGISTER_KERNEL_TYPED(OpName, Domain, VER, T1, T2, MemIndex, ClassName) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - OpName, \ - Domain, \ - VER, \ - T1##_##T2, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(MemIndex), \ - ClassName); - -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, MLFloat16, MLFloat16, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, MLFloat16, float, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, MLFloat16, double, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, float, MLFloat16, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, float, float, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, float, double, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, double, MLFloat16, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, double, float, 1, Dropout) -REGISTER_KERNEL_TYPED(TrainableDropout, kOnnxDomain, 9, double, double, 1, Dropout) - -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, MLFloat16, MLFloat16, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, MLFloat16, float, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, MLFloat16, double, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, float, MLFloat16, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, float, float, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, float, double, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, double, MLFloat16, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, double, float, 2, DropoutGrad) -REGISTER_KERNEL_TYPED(TrainableDropoutGrad, kMSDomain, 1, double, double, 2, DropoutGrad) - -} // namespace cuda -} // namespace onnxruntime