mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Structured kernels generate Meta registrations (#48116)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48116 If you port kernels to be structured, you get Meta kernels automatically generated for you. This is one payoff of structured kernels. Code generation was mercifully really simple, although at risk of "swiss cheese" syndrome: there's two new conditionals in the codegen to tweak behavior when generating for meta keys. It's not too bad right now but there's a risk of things getting out of hand. One way to rationalize the logic here would be to transmit "TensorMeta-ness" inside the TensorOptions (so tensor_from_meta can deal with it); then the "Meta" kernel magic would literally just be generating empty out_impls to call after all the scaffolding is done. But I didn't do this because it seemed like it would be more annoying short term. Also had to teach resize_ to work on meta tensors, since we use them to implement the out kernels. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer, ailzhang Differential Revision: D25056640 Pulled By: ezyang fbshipit-source-id: f8fcfa0dbb58a94d9b4196748f56e155f83b1521
This commit is contained in:
parent
47db191f0c
commit
b4f5efa7b2
7 changed files with 67 additions and 12 deletions
|
|
@ -131,6 +131,7 @@ genrule(
|
|||
"aten/src/ATen/RegisterQuantizedCPU.cpp",
|
||||
"aten/src/ATen/RegisterSparseCPU.cpp",
|
||||
"aten/src/ATen/RegisterMath.cpp",
|
||||
"aten/src/ATen/RegisterMeta.cpp",
|
||||
"aten/src/ATen/RegisterDefaultBackend.cpp",
|
||||
"aten/src/ATen/RegisterSchema.cpp",
|
||||
"aten/src/ATen/Functions.h",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@ struct TensorMeta {
|
|||
: sizes(_sizes), options(_options) {}
|
||||
};
|
||||
|
||||
inline Tensor meta_tensor_from_meta(const TensorMeta& meta) {
|
||||
// TODO: eliminate indirection
|
||||
return at::empty_meta(meta.sizes, meta.options);
|
||||
}
|
||||
|
||||
inline Tensor tensor_from_meta(const TensorMeta& meta) {
|
||||
// TODO: eliminate indirection
|
||||
return at::empty(meta.sizes, meta.options);
|
||||
|
|
|
|||
|
|
@ -77,12 +77,13 @@ Tensor& resize_as_(
|
|||
Tensor& resize_(
|
||||
Tensor& self,
|
||||
IntArrayRef size,
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
c10::optional<MemoryFormat> optional_memory_format,
|
||||
bool resize_storage) {
|
||||
if (self.has_names()) {
|
||||
return resize_named_tensor_(self, size, optional_memory_format);
|
||||
}
|
||||
auto* self_ = self.unsafeGetTensorImpl();
|
||||
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
|
||||
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage);
|
||||
if (optional_memory_format.has_value()) {
|
||||
auto memory_format =
|
||||
optional_memory_format.value();
|
||||
|
|
@ -95,5 +96,20 @@ Tensor& resize_(
|
|||
return self;
|
||||
}
|
||||
|
||||
Tensor& resize_(
|
||||
Tensor& self,
|
||||
IntArrayRef size,
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
return resize_(self, size, optional_memory_format, /*resize_storage=*/true);
|
||||
}
|
||||
|
||||
Tensor& resize_meta_(
|
||||
Tensor& self,
|
||||
IntArrayRef size,
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
// meta tensors don't have storage, so don't resize them
|
||||
return resize_(self, size, optional_memory_format, /*resize_storage=*/false);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size)
|
|||
inline TensorImpl* resize_impl_cpu_(
|
||||
TensorImpl* self,
|
||||
IntArrayRef size,
|
||||
c10::optional<IntArrayRef> stride) {
|
||||
c10::optional<IntArrayRef> stride,
|
||||
bool resize_storage = true) {
|
||||
if (self->sizes() == size && (!stride || self->strides() == stride)) {
|
||||
return self;
|
||||
}
|
||||
|
|
@ -57,7 +58,9 @@ inline TensorImpl* resize_impl_cpu_(
|
|||
self->set_sizes_contiguous(size);
|
||||
storage_size = self->numel();
|
||||
}
|
||||
maybe_resize_storage_cpu(self, storage_size);
|
||||
if (resize_storage) {
|
||||
maybe_resize_storage_cpu(self, storage_size);
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1693,6 +1693,7 @@
|
|||
CPU: resize_
|
||||
CUDA: resize_cuda_
|
||||
QuantizedCPU: quantized_resize_cpu_
|
||||
Meta: resize_meta_
|
||||
|
||||
- func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
|
|
|||
|
|
@ -2540,6 +2540,22 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
z = x + y
|
||||
self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
|
||||
|
||||
def test_upsample_nearest1d_meta(self):
|
||||
# TODO: this is not a sustainable way of testing meta functions,
|
||||
# but I want some quick scaffolding first before a more
|
||||
# integrated testing strategy
|
||||
# NB: Can't make the exponent too big, or it will overflow
|
||||
# signed 64-bit integer
|
||||
x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
|
||||
z = torch.nn.functional.interpolate(x, scale_factor=2)
|
||||
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
|
||||
|
||||
# interpolate doesn't seem to support out=
|
||||
# (not sure why passing None here doesn't work? How strange...)
|
||||
z = torch.empty_meta(0)
|
||||
torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
|
||||
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
|
||||
|
||||
def test_normal_shape(self):
|
||||
warned = False
|
||||
for device in torch.testing.get_all_device_types():
|
||||
|
|
|
|||
|
|
@ -244,7 +244,11 @@ class RegisterDispatchKey:
|
|||
assert_never(f)
|
||||
|
||||
def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
|
||||
if self.dispatch_key not in g.out.dispatch:
|
||||
if self.dispatch_key == 'Meta':
|
||||
assert self.dispatch_key not in g.out.dispatch, \
|
||||
"Do not explicitly specify Meta dispatch key on structured " \
|
||||
"functions, they will be automatically generated for you"
|
||||
elif self.dispatch_key not in g.out.dispatch:
|
||||
return []
|
||||
|
||||
# Inner helper function to close over g
|
||||
|
|
@ -272,14 +276,15 @@ class RegisterDispatchKey:
|
|||
sig = NativeSignature.from_schema(f.func)
|
||||
|
||||
if self.target is Target.DEFINITION:
|
||||
out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
|
||||
|
||||
# TODO: work a little harder to generate fresh names for 'result'
|
||||
# TODO: less praying that I picked the right argument name for 'self'
|
||||
|
||||
if k is SchemaKind.functional:
|
||||
out_expr = "result"
|
||||
prologue = "auto result = tensor_from_meta(meta_result);"
|
||||
if self.dispatch_key == "Meta":
|
||||
prologue = "auto result = meta_tensor_from_meta(meta_result);"
|
||||
else:
|
||||
prologue = "auto result = tensor_from_meta(meta_result);"
|
||||
elif k is SchemaKind.inplace:
|
||||
out_expr = "self"
|
||||
prologue = "// TODO: consistency check assert"
|
||||
|
|
@ -294,6 +299,12 @@ class RegisterDispatchKey:
|
|||
{out_expr}.resize_(meta_result.sizes);
|
||||
"""
|
||||
|
||||
if self.dispatch_key == "Meta":
|
||||
out_impl_call = "// meta function does nothing"
|
||||
else:
|
||||
out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
|
||||
out_impl_call = f"{out_impl_name}({out_expr}, {functional_exprs});"
|
||||
|
||||
device_guard = ""
|
||||
|
||||
if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
|
||||
|
|
@ -317,7 +328,7 @@ class RegisterDispatchKey:
|
|||
{device_guard}
|
||||
auto meta_result = meta::{meta_name}({functional_exprs});
|
||||
{prologue}
|
||||
{out_impl_name}({out_expr}, {functional_exprs});
|
||||
{out_impl_call}
|
||||
return {out_expr};
|
||||
}}
|
||||
"""
|
||||
|
|
@ -1048,6 +1059,7 @@ def main() -> None:
|
|||
|
||||
# TODO: how come ValuesView isn't a Sequence lol
|
||||
grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
|
||||
structured_native_functions = [g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)]
|
||||
|
||||
template_dir = os.path.join(options.source_path, "templates")
|
||||
|
||||
|
|
@ -1093,6 +1105,9 @@ def main() -> None:
|
|||
"QuantizedCUDA",
|
||||
"Math",
|
||||
"DefaultBackend",
|
||||
# Meta is a magic key: it is automatically generated for structured
|
||||
# kernels
|
||||
"Meta",
|
||||
]
|
||||
if options.backend_whitelist:
|
||||
dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist]
|
||||
|
|
@ -1129,9 +1144,7 @@ def main() -> None:
|
|||
})
|
||||
|
||||
cpu_fm.write('MetaFunctions.h', lambda: {
|
||||
'declarations':
|
||||
list(mapMaybe(compute_meta_function_declaration,
|
||||
(g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)))),
|
||||
'declarations': list(map(compute_meta_function_declaration, structured_native_functions)),
|
||||
})
|
||||
|
||||
schema_selector = selector
|
||||
|
|
|
|||
Loading…
Reference in a new issue