Structured kernels generate Meta registrations (#48116)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48116

If you port kernels to be structured, you get Meta kernels automatically
generated for you.  This is one payoff of structured kernels.
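
For example (a minimal sketch mirroring the test added in this PR;
torch.empty_meta was the meta-tensor constructor at the time), shape
inference runs end to end without allocating or computing anything:

    import torch

    # A meta tensor records sizes/dtype but has no storage, so this
    # never allocates the ~1.2e17 elements it describes.
    x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
    z = torch.nn.functional.interpolate(x, scale_factor=2)
    assert z.size() == (2 * 10 ** 8, 3, 4 * 10 ** 8)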

Code generation was mercifully simple, although at risk of
"swiss cheese" syndrome: there are two new conditionals in the codegen
that tweak behavior when generating for the Meta key.  It's not too
bad right now, but there's a risk of things getting out of hand.  One
way to rationalize the logic here would be to transmit "TensorMeta-ness"
inside the TensorOptions (so tensor_from_meta can deal with it); then
the "Meta" kernel magic would literally just be generating empty
out_impls to call after all the scaffolding is done.  But I didn't
do this because it seemed like it would be more annoying in the
short term.

Also had to teach resize_ to work on meta tensors, since the generated
out kernels call it to resize their outputs.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: bhosmer, ailzhang

Differential Revision: D25056640

Pulled By: ezyang

fbshipit-source-id: f8fcfa0dbb58a94d9b4196748f56e155f83b1521
Commit: b4f5efa7b2 (parent 47db191f0c)
Author: Edward Yang, 2020-12-02 07:47:13 -08:00, committed by Facebook GitHub Bot
7 changed files with 67 additions and 12 deletions

BUILD.bazel

@@ -131,6 +131,7 @@ genrule(
         "aten/src/ATen/RegisterQuantizedCPU.cpp",
         "aten/src/ATen/RegisterSparseCPU.cpp",
         "aten/src/ATen/RegisterMath.cpp",
+        "aten/src/ATen/RegisterMeta.cpp",
         "aten/src/ATen/RegisterDefaultBackend.cpp",
         "aten/src/ATen/RegisterSchema.cpp",
         "aten/src/ATen/Functions.h",

aten/src/ATen/TensorMeta.h

@@ -14,6 +14,11 @@ struct TensorMeta {
     : sizes(_sizes), options(_options) {}
 };
 
+inline Tensor meta_tensor_from_meta(const TensorMeta& meta) {
+  // TODO: eliminate indirection
+  return at::empty_meta(meta.sizes, meta.options);
+}
+
 inline Tensor tensor_from_meta(const TensorMeta& meta) {
   // TODO: eliminate indirection
   return at::empty(meta.sizes, meta.options);

aten/src/ATen/native/Resize.cpp

@@ -77,12 +77,13 @@ Tensor& resize_as_(
 Tensor& resize_(
     Tensor& self,
     IntArrayRef size,
-    c10::optional<MemoryFormat> optional_memory_format) {
+    c10::optional<MemoryFormat> optional_memory_format,
+    bool resize_storage) {
   if (self.has_names()) {
     return resize_named_tensor_(self, size, optional_memory_format);
   }
   auto* self_ = self.unsafeGetTensorImpl();
-  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
+  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage);
   if (optional_memory_format.has_value()) {
     auto memory_format =
         optional_memory_format.value();
@@ -95,5 +96,20 @@ Tensor& resize_(
   return self;
 }
 
+Tensor& resize_(
+    Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  return resize_(self, size, optional_memory_format, /*resize_storage=*/true);
+}
+
+Tensor& resize_meta_(
+    Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  // meta tensors don't have storage, so don't resize them
+  return resize_(self, size, optional_memory_format, /*resize_storage=*/false);
+}
+
 } // namespace native
 } // namespace at
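
As a usage sketch (assuming resize_ on a meta tensor dispatches to the
resize_meta_ registration added below in native_functions.yaml):

    import torch

    # Resizing a meta tensor updates its sizes but touches no storage,
    # since resize_meta_ passes resize_storage=False.
    x = torch.empty_meta(0)
    x.resize_(2 * 10 ** 8, 3)
    assert x.size() == (2 * 10 ** 8, 3)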

aten/src/ATen/native/Resize.h

@@ -43,7 +43,8 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size)
 inline TensorImpl* resize_impl_cpu_(
     TensorImpl* self,
     IntArrayRef size,
-    c10::optional<IntArrayRef> stride) {
+    c10::optional<IntArrayRef> stride,
+    bool resize_storage = true) {
   if (self->sizes() == size && (!stride || self->strides() == stride)) {
     return self;
   }
@@ -57,7 +58,9 @@ inline TensorImpl* resize_impl_cpu_(
     self->set_sizes_contiguous(size);
     storage_size = self->numel();
   }
-  maybe_resize_storage_cpu(self, storage_size);
+  if (resize_storage) {
+    maybe_resize_storage_cpu(self, storage_size);
+  }
 
   return self;
 }

aten/src/ATen/native/native_functions.yaml

@@ -1693,6 +1693,7 @@
     CPU: resize_
     CUDA: resize_cuda_
     QuantizedCPU: quantized_resize_cpu_
+    Meta: resize_meta_
 
 - func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
   use_c10_dispatcher: full

test/test_torch.py

@@ -2540,6 +2540,22 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         z = x + y
         self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
 
+    def test_upsample_nearest1d_meta(self):
+        # TODO: this is not a sustainable way of testing meta functions,
+        # but I want some quick scaffolding first before a more
+        # integrated testing strategy
+        # NB: Can't make the exponent too big, or it will overflow
+        # signed 64-bit integer
+        x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
+        z = torch.nn.functional.interpolate(x, scale_factor=2)
+        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
+
+        # interpolate doesn't seem to support out=
+        # (not sure why passing None here doesn't work? How strange...)
+        z = torch.empty_meta(0)
+        torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
+        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
+
     def test_normal_shape(self):
         warned = False
         for device in torch.testing.get_all_device_types():

tools/codegen/gen.py

@@ -244,7 +244,11 @@ class RegisterDispatchKey:
             assert_never(f)
 
     def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
-        if self.dispatch_key not in g.out.dispatch:
+        if self.dispatch_key == 'Meta':
+            assert self.dispatch_key not in g.out.dispatch, \
+                "Do not explicitly specify Meta dispatch key on structured " \
+                "functions, they will be automatically generated for you"
+        elif self.dispatch_key not in g.out.dispatch:
             return []
 
         # Inner helper function to close over g
@@ -272,14 +276,15 @@
         sig = NativeSignature.from_schema(f.func)
 
         if self.target is Target.DEFINITION:
-            out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
-
             # TODO: work a little harder to generate fresh names for 'result'
             # TODO: less praying that I picked the right argument name for 'self'
            if k is SchemaKind.functional:
                 out_expr = "result"
-                prologue = "auto result = tensor_from_meta(meta_result);"
+                if self.dispatch_key == "Meta":
+                    prologue = "auto result = meta_tensor_from_meta(meta_result);"
+                else:
+                    prologue = "auto result = tensor_from_meta(meta_result);"
             elif k is SchemaKind.inplace:
                 out_expr = "self"
                 prologue = "// TODO: consistency check assert"
@@ -294,6 +299,12 @@
                     {out_expr}.resize_(meta_result.sizes);
                     """
 
+            if self.dispatch_key == "Meta":
+                out_impl_call = "// meta function does nothing"
+            else:
+                out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
+                out_impl_call = f"{out_impl_name}({out_expr}, {functional_exprs});"
+
             device_guard = ""
             if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
@@ -317,7 +328,7 @@
                 {device_guard}
                 auto meta_result = meta::{meta_name}({functional_exprs});
                 {prologue}
-                {out_impl_name}({out_expr}, {functional_exprs});
+                {out_impl_call}
                 return {out_expr};
             }}
             """
@@ -1048,6 +1059,7 @@ def main() -> None:
     # TODO: how come ValuesView isn't a Sequence lol
     grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
+    structured_native_functions = [g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)]
 
     template_dir = os.path.join(options.source_path, "templates")
@@ -1093,6 +1105,9 @@ def main() -> None:
         "QuantizedCUDA",
         "Math",
         "DefaultBackend",
+        # Meta is a magic key: it is automatically generated for structured
+        # kernels
+        "Meta",
     ]
     if options.backend_whitelist:
         dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist]
@@ -1129,9 +1144,7 @@ def main() -> None:
     })
     cpu_fm.write('MetaFunctions.h', lambda: {
-        'declarations':
-            list(mapMaybe(compute_meta_function_declaration,
-                (g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)))),
+        'declarations': list(map(compute_meta_function_declaration, structured_native_functions)),
     })
 
     schema_selector = selector
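
Schematically, the wrapper emitted by the Target.DEFINITION branch above
differs between Meta and a backend key only in the prologue and the
out_impl call. A simplified sketch (the helper name, the single-Tensor
schema, and the _out_cpu suffix are hypothetical, not the actual gen.py
API):

    def sketch_generated_wrapper(name: str, dispatch_key: str) -> str:
        # Pick the prologue: Meta allocates a meta tensor (no storage),
        # backends allocate a real empty tensor.
        if dispatch_key == "Meta":
            prologue = "auto result = meta_tensor_from_meta(meta_result);"
            out_impl_call = "// meta function does nothing"
        else:
            prologue = "auto result = tensor_from_meta(meta_result);"
            out_impl_call = f"at::native::{name}_out_cpu(result, self);"
        return f"""
    Tensor wrapper_{name}(const Tensor& self) {{
      auto meta_result = meta::{name}(self);
      {prologue}
      {out_impl_call}
      return result;
    }}
    """

    print(sketch_generated_wrapper("upsample_nearest1d", "Meta"))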