mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Bug fix for shape of optional output in Dropout op (#1507)
* Bug fix for shape of optional output in Dropout op
* Exclude new test from NGraph EP
* Account for the fact that mask could be of a different type in different opset variants of the op
* Make accompanying CUDA changes
* Fix build break
* Exclude Opset 7 test for TensorRT EP
* Address PR comments
This commit is contained in:
parent
57e2482089
commit
465b30e3ca
5 changed files with 69 additions and 6 deletions
|
|
@ -43,7 +43,18 @@ class IdentityOp final : public OpKernel {
|
|||
}
|
||||
|
||||
if (is_dropout) {
|
||||
context->Output(1, std::vector<int64_t>());
|
||||
Tensor* mask = context->Output(1, shape);
|
||||
// a 'nullptr' returned would make it an unused optional output
|
||||
if (mask != nullptr) {
|
||||
// Opset 7 differs with Opset 10 in that the type of the 'mask'
|
||||
// output is tied with the type of the input in Opset 7 whereas
|
||||
// the type of 'mask' in Opset 10 is 'bool' always
|
||||
// so we have a common solution
|
||||
void* mask_data = mask->MutableDataRaw();
|
||||
// In 'test'/'inference' mode, there are no input values dropped out
|
||||
// so fill the buffer with 0/false
|
||||
memset(mask_data, 0, mask->SizeInBytes());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Un
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, Flatten);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Squeeze);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Identity);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, Dropout);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, Dropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Gather);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
|
||||
|
|
@ -515,6 +515,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Shrink);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Shrink);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout);
|
||||
|
||||
static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -525,7 +526,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Squeeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Identity)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Gather)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm)>,
|
||||
|
|
@ -840,6 +841,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -5,13 +5,28 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
// Registers the Dropout kernel in kOnnxDomain for kCudaExecutionProvider,
// using the versioned registration macro with the version arguments `7, 9`.
// The type constraint "T" accepts MLFloat16, float and double tensors, and the
// kernel is implemented by IdentityOp<true>.
// NOTE(review): the `true` template argument is presumably the 'is_dropout'
// flag checked in IdentityOp's compute path — confirm against the class template.
// NOTE(review): Alias(0, 0) presumably allows output 0 to reuse input 0's
// buffer — confirm against KernelDefBuilder.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Dropout,
|
||||
kOnnxDomain,
|
||||
7, 9,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>()})
|
||||
.Alias(0, 0),
|
||||
IdentityOp<true>);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Dropout,
|
||||
kOnnxDomain,
|
||||
7,
|
||||
10,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(), DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()})
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>()})
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>())
|
||||
.Alias(0, 0),
|
||||
IdentityOp<true>);
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,18 @@ class IdentityOp final : public CudaKernel {
|
|||
}
|
||||
|
||||
if (is_dropout) {
|
||||
context->Output(1, std::vector<int64_t>());
|
||||
Tensor* mask = context->Output(1, shape);
|
||||
// a 'nullptr' returned would make it an unused optional output
|
||||
if (mask != nullptr) {
|
||||
// Opset 7 differs with Opset 10 in that the type of the 'mask'
|
||||
// output is tied with the type of the input in Opset 7 whereas
|
||||
// the type of 'mask' in Opset 10 is 'bool' always
|
||||
// so we have a common solution
|
||||
void* mask_data = mask->MutableDataRaw();
|
||||
// In 'test'/'inference' mode, there are no input values dropped out
|
||||
// so fill the buffer with 0/false
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_data, 0, mask->SizeInBytes()));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -23,5 +23,29 @@ TEST(Dropout, Opset10) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Dropout at opset 10 with the optional 'mask' output requested.
// Expects 'Y' to carry the same values as 'X' and 'mask' to be a bool tensor
// with the same dims as 'X', filled with false.
TEST(Dropout, WithOptionalOutputOpset10) {
|
||||
OpTester test("Dropout", 10, kOnnxDomain);
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("X", dims, {1.0f, 2.0f, 3.0f, 5.0f});
|
||||
test.AddOutput<float>("Y", dims, {1.0f, 2.0f, 3.0f, 5.0f});
|
||||
test.AddOutput<bool>("mask", dims, {false, false, false, false});
|
||||
// The NGraph execution provider doesn't seem to support 'Dropout' with the
// optional mask output, so it is excluded from this run.
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider});
|
||||
}
|
||||
|
||||
// Dropout at opset 7 with the optional 'mask' output requested.
// Expects 'Y' to carry the same values as 'X' and 'mask' to be a float tensor
// with the same dims as 'X', filled with 0.0f.
TEST(Dropout, WithOptionalOutputOpset7) {
|
||||
// Opset 7 differs from Opset 10 in the type of the 'mask' output:
|
||||
// in Opset 7 it is tied to the type of the input (float here), whereas
|
||||
// in Opset 10 the type of 'mask' is always 'bool'.
|
||||
OpTester test("Dropout", 7, kOnnxDomain);
|
||||
std::vector<int64_t> dims{2, 2};
|
||||
test.AddInput<float>("X", dims, {1.0f, 2.0f, 3.0f, 5.0f});
|
||||
test.AddOutput<float>("Y", dims, {1.0f, 2.0f, 3.0f, 5.0f});
|
||||
test.AddOutput<float>("mask", dims, {0.0f, 0.0f, 0.0f, 0.0f});
|
||||
// The NGraph execution provider doesn't seem to support 'Dropout' with the optional mask output.
|
||||
// The TensorRT execution provider doesn't seem to support 'Dropout' with a non-boolean mask output.
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue