diff --git a/BUILD.bazel b/BUILD.bazel
index 747ff1697d3..621494b3dc7 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -131,6 +131,7 @@ genrule(
         "aten/src/ATen/RegisterQuantizedCPU.cpp",
         "aten/src/ATen/RegisterSparseCPU.cpp",
         "aten/src/ATen/RegisterMath.cpp",
+        "aten/src/ATen/RegisterMeta.cpp",
         "aten/src/ATen/RegisterDefaultBackend.cpp",
         "aten/src/ATen/RegisterSchema.cpp",
         "aten/src/ATen/Functions.h",
diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h
index abca65feda1..2aac50745be 100644
--- a/aten/src/ATen/TensorMeta.h
+++ b/aten/src/ATen/TensorMeta.h
@@ -14,6 +14,11 @@ struct TensorMeta {
     : sizes(_sizes), options(_options) {}
 };
 
+inline Tensor meta_tensor_from_meta(const TensorMeta& meta) {
+  // TODO: eliminate indirection
+  return at::empty_meta(meta.sizes, meta.options);
+}
+
 inline Tensor tensor_from_meta(const TensorMeta& meta) {
   // TODO: eliminate indirection
   return at::empty(meta.sizes, meta.options);
diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp
index d6da309d4cf..e5a0423e493 100644
--- a/aten/src/ATen/native/Resize.cpp
+++ b/aten/src/ATen/native/Resize.cpp
@@ -77,12 +77,13 @@ Tensor& resize_as_(
 Tensor& resize_(
     Tensor& self,
     IntArrayRef size,
-    c10::optional<MemoryFormat> optional_memory_format) {
+    c10::optional<MemoryFormat> optional_memory_format,
+    bool resize_storage) {
   if (self.has_names()) {
     return resize_named_tensor_(self, size, optional_memory_format);
   }
   auto* self_ = self.unsafeGetTensorImpl();
-  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
+  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage);
   if (optional_memory_format.has_value()) {
     auto memory_format =
         optional_memory_format.value();
@@ -95,5 +96,20 @@
   return self;
 }
 
+Tensor& resize_(
+    Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  return resize_(self, size, optional_memory_format, /*resize_storage=*/true);
+}
+
+Tensor& resize_meta_(
+    Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  // meta tensors don't have storage, so don't resize them
+  return resize_(self, size, optional_memory_format, /*resize_storage=*/false);
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h
index 501cacfbd07..d3d8faf3aa2 100644
--- a/aten/src/ATen/native/Resize.h
+++ b/aten/src/ATen/native/Resize.h
@@ -43,7 +43,8 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size)
 inline TensorImpl* resize_impl_cpu_(
     TensorImpl* self,
     IntArrayRef size,
-    c10::optional<IntArrayRef> stride) {
+    c10::optional<IntArrayRef> stride,
+    bool resize_storage = true) {
   if (self->sizes() == size && (!stride || self->strides() == stride)) {
     return self;
   }
@@ -57,7 +58,9 @@ inline TensorImpl* resize_impl_cpu_(
     self->set_sizes_contiguous(size);
     storage_size = self->numel();
   }
-  maybe_resize_storage_cpu(self, storage_size);
+  if (resize_storage) {
+    maybe_resize_storage_cpu(self, storage_size);
+  }
 
   return self;
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 93f2f7ce22a..3441646aa19 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1693,6 +1693,7 @@
     CPU: resize_
     CUDA: resize_cuda_
     QuantizedCPU: quantized_resize_cpu_
+    Meta: resize_meta_
 
 - func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
   use_c10_dispatcher: full
diff --git a/test/test_torch.py b/test/test_torch.py
index 949e586d3b4..6d3d745ad7a 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2540,6 +2540,22 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         z = x + y
         self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
 
+    def test_upsample_nearest1d_meta(self):
+        # TODO: this is not a sustainable way of testing meta functions,
+        # but I want some quick scaffolding first before a more
+        # integrated testing strategy
+        # NB: Can't make the exponent too big, or it will overflow
+        # signed 64-bit integer
+        x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
+        z = torch.nn.functional.interpolate(x, scale_factor=2)
+        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
+
+        # interpolate doesn't seem to support out=
+        # (not sure why passing None here doesn't work? How strange...)
+        z = torch.empty_meta(0)
+        torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
+        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
+
     def test_normal_shape(self):
         warned = False
         for device in torch.testing.get_all_device_types():
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index a0a1f21d13d..d90fe1b17f4 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -244,7 +244,11 @@ class RegisterDispatchKey:
             assert_never(f)
 
     def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
-        if self.dispatch_key not in g.out.dispatch:
+        if self.dispatch_key == 'Meta':
+            assert self.dispatch_key not in g.out.dispatch, \
+                "Do not explicitly specify Meta dispatch key on structured " \
+                "functions, they will be automatically generated for you"
+        elif self.dispatch_key not in g.out.dispatch:
             return []
 
         # Inner helper function to close over g
@@ -272,14 +276,15 @@ class RegisterDispatchKey:
             sig = NativeSignature.from_schema(f.func)
 
             if self.target is Target.DEFINITION:
-                out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
-
                 # TODO: work a little harder to generate fresh names for 'result'
                 # TODO: less praying that I picked the right argument name for 'self'
                 if k is SchemaKind.functional:
                     out_expr = "result"
-                    prologue = "auto result = tensor_from_meta(meta_result);"
+                    if self.dispatch_key == "Meta":
+                        prologue = "auto result = meta_tensor_from_meta(meta_result);"
+                    else:
+                        prologue = "auto result = tensor_from_meta(meta_result);"
                 elif k is SchemaKind.inplace:
                     out_expr = "self"
                     prologue = "// TODO: consistency check assert"
@@ -294,6 +299,12 @@ class RegisterDispatchKey:
 {out_expr}.resize_(meta_result.sizes);
 """
 
+                if self.dispatch_key == "Meta":
+                    out_impl_call = "// meta function does nothing"
+                else:
+                    out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
+                    out_impl_call = f"{out_impl_name}({out_expr}, {functional_exprs});"
+
                 device_guard = ""
                 if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
@@ -317,7 +328,7 @@ class RegisterDispatchKey:
 {device_guard}
     auto meta_result = meta::{meta_name}({functional_exprs});
     {prologue}
-    {out_impl_name}({out_expr}, {functional_exprs});
+    {out_impl_call}
     return {out_expr};
 }}
 """
@@ -1048,6 +1059,7 @@ def main() -> None:
     # TODO: how come ValuesView isn't a Sequence lol
     grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
+    structured_native_functions = [g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)]
 
     template_dir = os.path.join(options.source_path, "templates")
@@ -1093,6 +1105,9 @@ def main() -> None:
         "QuantizedCUDA",
         "Math",
         "DefaultBackend",
+        # Meta is a magic key: it is automatically generated for structured
+        # kernels
+        "Meta",
     ]
     if options.backend_whitelist:
         dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist]
@@ -1129,9 +1144,7 @@ def main() -> None:
     })
 
     cpu_fm.write('MetaFunctions.h', lambda: {
-        'declarations':
-            list(mapMaybe(compute_meta_function_declaration,
-                (g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)))),
+        'declarations': list(map(compute_meta_function_declaration, structured_native_functions)),
     })
 
     schema_selector = selector
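
For context on the new test: a meta tensor carries only sizes, strides, and a dtype, so the generated Meta kernel runs just the shape computation and never touches storage (which is also why resize_meta_ skips the storage resize). A minimal sketch of what this patch enables, assuming a build with it applied (torch.empty_meta was the prototype entry point at the time):

    import torch

    # No storage is ever allocated on the meta backend, so enormous
    # sizes are fine as long as numel fits in a signed 64-bit integer.
    x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)

    # Only the generated shape function runs; the out-kernel call is
    # skipped ("// meta function does nothing" in the generated code).
    z = torch.nn.functional.interpolate(x, scale_factor=2)
    print(z.size())  # torch.Size([200000000, 3, 400000000])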