Add native_dropout (#63937)

Summary:
Adds native_dropout to have a reasonable target for torchscript in auto diff. native_dropout has scale and train as arguments in its signature, this makes native_dropout more consistent with other operators and removes conditionals in the autodiff definition.

cc gmagogsfm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63937

Reviewed By: mruberry

Differential Revision: D32477657

Pulled By: ngimel

fbshipit-source-id: d37b137a37acafa50990f60c77f5cea2818454e4
This commit is contained in:
jiej 2021-11-18 19:39:53 -08:00 committed by Facebook GitHub Bot
parent a39060c001
commit ca92111758
13 changed files with 191 additions and 80 deletions

View file

@ -11,6 +11,7 @@ TORCH_LIBRARY_IMPL(_, Named, m) {
TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("_cdist_forward", CppFunction::makeFallthrough());
m.impl("_fused_dropout", CppFunction::makeFallthrough());
m.impl("native_dropout", CppFunction::makeFallthrough());
m.impl("_local_scalar_dense", CppFunction::makeFallthrough());
m.impl("_sparse_log_softmax.Dimname", CppFunction::makeFallthrough());
m.impl("_sparse_log_softmax.int", CppFunction::makeFallthrough());

View file

@ -3,7 +3,8 @@
#include <ATen/NamedTensorUtils.h>
#include <c10/util/irange.h>
namespace at { namespace native {
namespace at {
namespace native {
namespace {
@ -85,11 +86,39 @@ ALIAS_SPECIALIZATION(_feature_alpha_dropout, true, true )
} // anonymous namespace
// CPU implementation of native_dropout: returns (output, mask).
// A missing `train` value is treated as training mode (see the
// `!train.has_value() || *train` test below).
std::tuple<Tensor,Tensor>
native_dropout_cpu(const Tensor& input, double p, c10::optional<bool> train) {
// Empty input: nothing to mask. NOTE(review): this returns `input` itself
// rather than a clone — confirm that aliasing the input is intended (the
// CUDA variant added in this same commit clones on its short-cut paths).
if (input.numel() == 0) {
return std::make_tuple(input, at::empty_like(input, input.options()));
}
Tensor mask;
Tensor output;
if (!train.has_value() || *train) {
// Keep-probability: each element survives with probability 1 - p.
double p1m = 1. - p;
// Check for probability of zero to avoid divide by zero and NaN results
double scale = p1m == 0 ? 0. : 1. / p1m;
mask = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
mask.bernoulli_(p1m);
// Scale surviving elements by 1/(1-p) so the expected value is unchanged.
output = input.mul(mask).mul_(scale);
} else {
// Eval mode: identity output with an all-ones mask.
mask = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
output = input.clone();
}
return std::make_tuple(output, mask);
}
// Backward of native_dropout on CPU: grad_input = grad * mask * scale,
// where `scale` is the same 1/(1-p) factor applied in the forward pass
// (or 0 / 1 for the p == 1 and non-training cases, per derivatives.yaml).
Tensor native_dropout_backward_cpu(const Tensor& grad, const Tensor& mask, double scale) {
Tensor result = grad * mask * scale;
return result;
}
Tensor dropout(const Tensor& input, double p, bool train) {
auto result = [&]() {
NoNamesGuard guard;
if (train && is_fused_kernel_acceptable(input, p)) {
return std::get<0>(at::_fused_dropout(input, 1 - p));
return std::get<0>(at::native_dropout(input, p, train));
}
return _dropout<false>(input, p, train);
}();
@ -125,4 +154,5 @@ Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) {
return _feature_alpha_dropout<true>(input, p, train);
}
}} // namespace at::native
} // namespace native
} // namespace at

View file

@ -26,22 +26,22 @@ template <
typename accscalar_t,
typename IndexType,
int ADims,
int VEC>
int VEC,
typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void fused_dropout_kernel_vec(
at::cuda::detail::TensorInfo<scalar_t, IndexType> a,
at::cuda::detail::TensorInfo<scalar_t, IndexType> b,
at::cuda::detail::TensorInfo<uint8_t, IndexType> c,
IndexType totalElements,
accscalar_t p,
PhiloxCudaState philox_args) {
__global__ void
fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,
at::cuda::detail::TensorInfo<scalar_t, IndexType> b,
at::cuda::detail::TensorInfo<mask_t, IndexType> c,
IndexType totalElements, accscalar_t p,
PhiloxCudaState philox_args) {
// make sure we don't break assumption that we can't have > 4 elements / thread
static_assert(VEC <= 4, "Value of VEC must be in [2, 4]");
using LoadT = memory::aligned_vector<scalar_t, VEC>;
using MaskLoadT = memory::aligned_vector<uint8_t, VEC>;
using MaskLoadT = memory::aligned_vector<mask_t, VEC>;
auto seeds = at::cuda::philox::unpack(philox_args);
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
@ -51,11 +51,10 @@ __global__ void fused_dropout_kernel_vec(
std::get<1>(seeds),
&state);
accscalar_t pinv = accscalar_t(1)/p;
// Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
// in the vec=2 and vec=4 cases.
bool gridxvec_loop_state = 0;
accscalar_t scale = 1.0 / p;
float4 rand;
@ -94,13 +93,13 @@ __global__ void fused_dropout_kernel_vec(
*value = *reinterpret_cast<LoadT*>(&a.data[linearIndex]);
scalar_t r[VEC];
uint8_t mask[VEC];
mask_t mask[VEC];
// Perform the actual computation
#pragma unroll
for (int ii = 0; ii < VEC; ii++) {
r[ii] = src[ii]*(&rand.x)[ii]*pinv;
mask[ii] = (uint8_t)(&rand.x)[ii];
r[ii] = src[ii]*(&rand.x)[ii]*scale;
mask[ii] = (mask_t)(&rand.x)[ii];
}
// Vectorized writes for both mask & result
*(reinterpret_cast<LoadT*>(&b.data[linearIndex])) = *reinterpret_cast<LoadT*>(&r[0]);
@ -115,17 +114,17 @@ template <
typename accscalar_t,
typename IndexType,
int ADims,
int BDims = ADims>
int BDims = ADims,
typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void fused_dropout_kernel(
cuda::detail::TensorInfo<scalar_t, IndexType> a,
cuda::detail::TensorInfo<scalar_t, IndexType> b,
cuda::detail::TensorInfo<uint8_t, IndexType> c,
IndexType totalElements,
accscalar_t p,
PhiloxCudaState philox_args) {
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
cuda::detail::TensorInfo<scalar_t, IndexType> b,
cuda::detail::TensorInfo<mask_t, IndexType> c,
IndexType totalElements, accscalar_t p,
PhiloxCudaState philox_args) {
auto seeds = at::cuda::philox::unpack(philox_args);
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
@ -133,8 +132,7 @@ __global__ void fused_dropout_kernel(
idx,
std::get<1>(seeds),
&state);
accscalar_t pinv = accscalar_t(1)/p;
accscalar_t scale = 1.0 / p;
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) *
blockDim.x * gridDim.x * UNROLL;
@ -163,15 +161,15 @@ __global__ void fused_dropout_kernel(
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset =
cuda::detail::IndexToOffset<scalar_t, IndexType, BDims>::get(li, b);
b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv;
c.data[bOffset] = (uint8_t)(&rand.x)[ii];
b.data[bOffset] = src[ii]*(&rand.x)[ii]*scale;
c.data[bOffset] = (mask_t)(&rand.x)[ii];
}
}
__syncthreads();
}
}
template<typename scalar_t, typename accscalar_t>
template<typename mask_t, typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor& src, const at::Tensor& mask, accscalar_t scale){
auto iter = at::TensorIteratorConfig()
.check_all_same_dtype(false)
@ -182,7 +180,7 @@ void masked_scale_kernel(at::Tensor& ret, const at::Tensor& src, const at::Tenso
at::native::gpu_kernel(
iter,
[=]GPU_LAMBDA(const scalar_t src_val, const uint8_t mask_val) -> scalar_t {
[=]GPU_LAMBDA(const scalar_t src_val, const mask_t mask_val) -> scalar_t {
return (float)mask_val * src_val * scale;
});
}
@ -206,7 +204,7 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
return can_vectorize ? vec_size : 1;
}
template <typename index_type>
template <typename index_type, typename mask_t>
inline void launcher(
const Tensor& self,
Tensor& ret,
@ -229,7 +227,7 @@ inline void launcher(
auto ret_info =
cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
auto mask_info =
cuda::detail::getTensorInfo<uint8_t, index_type>(mask);
cuda::detail::getTensorInfo<mask_t, index_type>(mask);
self_info.collapseDims();
ret_info.collapseDims();
mask_info.collapseDims(); // ret and mask are collapsed to 1d
@ -321,14 +319,16 @@ inline void launcher(
} //anonymous namespace
template <typename mask_t>
std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, c10::optional<Generator> gen_){
auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
Tensor ret = at::empty_like(self);
Tensor mask = at::empty_like(self, self.options().dtype(kByte));
dropout_cuda(CUDAGeneratorImpl* gen, const Tensor& self, double p){
Tensor mask = at::empty_like(self, self.options().dtype(c10::CppTypeToScalarType<mask_t>::value));
const int64_t nelem = self.numel();
//empty tensors should not get here, but just in case, avoid FPE
if (nelem==0) return std::tuple<Tensor,Tensor>(self, mask);
// empty tensors should not get here, but just in case, avoid FPE
// non-training short-cut
if (nelem==0) return std::tuple<Tensor,Tensor>(self.clone(), mask);
Tensor ret = at::empty_like(self);
const int64_t block_size = 256;
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
dim3 dim_block(block_size);
@ -343,25 +343,62 @@ fused_dropout_cuda(const Tensor& self, double p, c10::optional<Generator> gen_){
rng_engine_inputs = gen->philox_cuda_state(counter_offset);
}
if (cuda::detail::canUse32BitIndexMath(self)){
launcher<unsigned int>(
launcher<unsigned int, mask_t>(
self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block);
} else {
launcher<uint64_t>(
launcher<uint64_t, mask_t>(
self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block);
}
return std::tuple<Tensor,Tensor>(ret, mask);
}
Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
Tensor ret = at::empty_like(self, self.suggest_memory_format());
TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
// CUDA entry point for native_dropout: returns (output, Bool mask).
// A missing `train` value is treated as training mode.
std::tuple<Tensor,Tensor>
native_dropout_cuda(const Tensor& self, double p, c10::optional<bool> train){
// short-cut for train == false
if (train.has_value() && !train.value()) {
// Eval mode: identity output with an all-true mask.
return std::make_tuple(self.clone(), at::ones_like(self, self.options().dtype(c10::CppTypeToScalarType<bool>::value)));
}
// short-cut
if (p == 1) {
// native_dropout_cuda is in derivatives.yaml, so we don't need to add data
// dependency from output to input for autograd
auto ret = at::zeros_like(self);
auto mask = at::zeros_like(self, self.options().dtype(c10::CppTypeToScalarType<bool>::value));
return std::tuple<Tensor,Tensor>(ret, mask);
}
// native_dropout takes no Generator argument, so use the default CUDA one.
auto gen = get_generator_or_default<CUDAGeneratorImpl>(c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
// dropout_cuda expects the keep-probability (1 - p), not the drop-probability.
double p1m = 1. - p;
return dropout_cuda<bool>(gen, self, p1m);
}
// TODO: _fused_dropout_cuda is to be removed, see PR #63937
// Legacy fused dropout kept for backward compatibility: same templated
// dropout_cuda path as native_dropout, but with a uint8_t mask and with
// `p` already being the keep-probability (callers pass 1 - drop_prob).
std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, c10::optional<Generator> gen_){
auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
return dropout_cuda<uint8_t>(gen, self, p);
}
template <typename mask_t>
Tensor dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
Tensor ret = at::empty_like(grad, grad.suggest_memory_format());
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] {
using accscalar_t = acc_type<scalar_t, true>;
accscalar_t pa = (accscalar_t)(scale);
masked_scale_kernel<scalar_t>(ret, self, mask, pa);
masked_scale_kernel<mask_t, scalar_t>(ret, grad, mask, (accscalar_t)scale);
});
return ret;
}
// Backward of native_dropout on CUDA: grad_input = grad * mask * scale.
// native_dropout_cuda produces a Bool mask, which is enforced here.
Tensor native_dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Mask should be Bool Scalar Type", mask.scalar_type());
return dropout_backward_cuda<bool>(grad, mask, scale);
}
// TODO: masked_scale_cuda is to be removed, see PR #63937
// Legacy backward for _fused_dropout: expects the uint8 mask produced by
// fused_dropout_cuda, then dispatches to the shared templated backward.
Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
return dropout_backward_cuda<uint8_t>(self, mask, scale);
}
}
}

View file

@ -176,6 +176,17 @@
dispatch:
CUDA: masked_scale_cuda
- func: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
variants: function
dispatch:
CPU: native_dropout_cpu
CUDA: native_dropout_cuda
- func: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
dispatch:
CPU: native_dropout_backward_cpu
CUDA: native_dropout_backward_cuda
- func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)
- func: _sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)

View file

@ -1650,6 +1650,30 @@ graph(%Ra, %Rb):
m = self.createFunctionFromGraph(g)
self.assertEqual(outputs, m(*inputs))
@unittest.skipIf(not RUN_CUDA, "test requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
def test_native_dropout_corner_case(self):
with disable_autodiff_subgraph_inlining():
def t(x, p: float, t: bool):
o = torch.dropout(x, p, t)
return o
jit_t = torch.jit.script(t)
x = torch.randn(5).requires_grad_()
FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True))
for train in [True, False]:
for p in [0.0, 1.0]:
for device in ["cuda", "cpu"]:
x = torch.randn(5).to(device=device).requires_grad_()
x_ref = x.detach().requires_grad_()
o = jit_t(x, p, train)
o_ref = t(x_ref, p, train)
o.sum().backward()
o_ref.sum().backward()
assert(o.equal(o_ref))
assert(x.grad.equal(x_ref.grad))
@slowTest
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
def test_dropout_module_requires_grad(self):
@ -1690,7 +1714,7 @@ graph(%Ra, %Rb):
for requires_grad in (True, False):
X = torch.randn(M, M, requires_grad=requires_grad)
if requires_grad:
FileCheck().check("aten::bernoulli_").run(scripted.graph_for(X, profile_and_replay=True))
FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True))
self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X))
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph')
@ -1714,7 +1738,7 @@ graph(%Ra, %Rb):
for requires_grad in (True, False):
X = torch.randn(M, M, requires_grad=requires_grad)
if requires_grad:
FileCheck().check("aten::bernoulli_").run(scripted_training.graph_for(X, profile_and_replay=True))
FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True))
self.assertIn('aten::bernoulli_', profile(scripted_training, X))
self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X))

View file

@ -6683,6 +6683,20 @@ class TestNN(NNTestCase):
bad_input = torch.randn(3, 1)
test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_native_dropout_corner_case(self):
for train in [True, False]:
for p in [0.0, 1.0]:
for device in ["cuda", "cpu"]:
x = torch.randn(5).to(device=device).requires_grad_()
x_ref = x.detach().requires_grad_()
o = torch.native_dropout(x, p, train)[0]
o_ref = torch.dropout(x_ref, p, train)
o.sum().backward()
o_ref.sum().backward()
assert(o.equal(o_ref))
assert(x.grad.equal(x_ref.grad))
def test_invalid_dropout_p(self):
v = torch.ones(1)
self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))

View file

@ -534,6 +534,13 @@
- name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
self: _fused_dropout_backward(grad, result1, p)
- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))"
- name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)"
mask: 'not_implemented("native_dropout_backward: mask")'
- name: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)
self: eig_backward(grads, self, eigenvectors, eigenvalues, eigenvectors_return)

View file

@ -927,6 +927,15 @@ Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
}
}
// scale == (1 / (1 - prob))
Tensor infinitely_differentiable_native_dropout_backward(const Tensor& grad, const Tensor& mask, double scale) {
return grad * (mask.type_as(grad) * scale);
}
Tensor native_dropout_double_backward(const Tensor& ggI, const Tensor& grad, const Tensor& mask, double scale) {
return ggI.type_as(grad) * (mask.type_as(grad) * scale);
}
Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) {
if (input.is_cuda()) {
auto mask = (input == value).logical_or_(input.isnan().logical_and_(value.isnan()));

View file

@ -92,6 +92,8 @@ at::Tensor sparse_sparse_matmul_backward(const at::Tensor& grad, const at::Tenso
at::Tensor renorm_backward(const at::Tensor & grad, const at::Tensor & self, const at::Scalar& p, int64_t dim, const at::Scalar& maxnorm);
at::Tensor repeat_backward(at::Tensor grad, at::IntArrayRef repeats, at::IntArrayRef input_shape);
at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m);
at::Tensor infinitely_differentiable_native_dropout_backward(const at::Tensor& grad, const at::Tensor& mask, double scale);
at::Tensor native_dropout_double_backward(const at::Tensor& ggI, const at::Tensor& grad, const at::Tensor& mask, double scale);
at::Tensor evenly_distribute_backward(at::Tensor grad, const at::Tensor & input, const at::Tensor & value);
at::Tensor sgn_backward(Tensor result, Tensor grad, Tensor self);
at::Tensor var_backward(at::Tensor grad, const at::Tensor& self, c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction, bool keepdim);

View file

@ -1142,6 +1142,7 @@ bool Node::isNondeterministic() const {
"aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
"aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
"aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
"aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
"aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
"aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
"aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",

View file

@ -1174,40 +1174,16 @@ const std::vector<std::string> functions = {
return grad_input, None, grad_weight, grad_bias, None, None
return output, backward
def AD_fused_dropout_backward(grad,
mask,
p1m: float):
p1r = 1. / p1m
grad_input = grad * (mask.type_as(grad) * p1r)
return grad_input
def dropout(input,
p: float,
train: bool):
use_cuda = input.is_cuda
# lowering is specialized for cuda because cuda fuser can efficiently fuse those operations
# for cpu backend, where fusions are disabled, a different lowering that is more efficient
# in the absence of fusion is used
p1m = 1. - p
if train:
if use_cuda:
mask = torch.rand_like(input, memory_format=1) < p1m
res = mask.type_as(input) * input * (1./p1m)
else:
mask = torch.empty_like(input, memory_format=1)
mask.bernoulli_(p1m)
res = mask * input / p1m
else:
p1m = 1.
res = input
mask = torch.empty_like(input, memory_format=1)
# if `train == false` we need to set `p1m` to 0 so `scale == 1`
p1m = (1. - p) * float(train)
scale = 1. / (float(p1m == 0.) + p1m)
res,mask = torch.native_dropout(input, p, train)
def backward(grad_output):
use_cuda = grad_output.is_cuda
if use_cuda:
grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
else:
grad_input = grad_output * mask / p1m
grad_input = torch.native_dropout_backward(grad_output, mask, scale)
return grad_input, None, None
return res, backward

View file

@ -664,6 +664,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.narrow_copy: lambda input, dim, start, length: -1,
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
torch.native_dropout : lambda input, p, train: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
torch.native_norm: lambda input, p=2: -1,

View file

@ -68,9 +68,7 @@ nn_functional_tests = [
('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
('dropout', (S, S, S), (0.5,), '', (True,
['aten::bernoulli_',
'aten::empty_like', 'aten::mul', 'aten::div'])),
('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
('alpha_dropout', (S, S, S), (0.5,)),
('dropout2d', (S, S, S), (0.5,)),
('dropout3d', (S, S, S), (0.5,)),