From 0f81a69a9630a5c9e70d534679277188fa0324ae Mon Sep 17 00:00:00 2001
From: Edward Yang
Date: Wed, 3 Mar 2021 11:19:08 -0800
Subject: [PATCH] Make meta a device (getting rid of empty_meta) (#53143)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53143

Meta is now an honest-to-goodness device type, like cpu, so you can use
device='meta' to trigger allocation of meta tensors. This is way better
than empty_meta, since we now have a working API for most factory
functions. (They don't necessarily work yet, though, because we still
need to register Meta versions of those functions.)

Some subtleties:

- I decided to drop the concept of CPU versus CUDA meta tensors; meta
  tensors are device agnostic. It's hard to say exactly what the correct
  level of abstraction is here, but in this particular case
  implementation considerations trump semantic considerations: it is way
  easier to have just a meta device than to have a meta device AND a cpu
  device AND a cuda device. This may limit the applicability of meta
  tensors for tracing models that do explicit cpu()/cuda() conversions
  (unless, perhaps, we make those operations no-ops on meta tensors).
- I noticed that the DeviceType uppercase strings are kind of weird. Are
  they really supposed to be all caps?
- I moved the Meta dispatch key to live with the rest of the "device"
  dispatch keys.
- I intentionally did NOT add a Backend for Meta. For now, I'm going to
  hope meta tensors never exercise any of the Backend conversion code;
  even if they do, it is better to fix that code to just stop converting
  to and from Backend.

Signed-off-by: Edward Z. Yang

Test Plan: Imported from OSS

Reviewed By: samestep

Differential Revision: D26763552

Pulled By: ezyang

fbshipit-source-id: 14633b6ca738e60b921db66a763155d01795480d
---
 aten/src/ATen/TensorIterator.cpp              | 12 +---
 aten/src/ATen/detail/MetaGuardImpl.cpp        |  9 +++
 aten/src/ATen/native/MetaTensor.cpp           |  7 +--
 aten/src/ATen/native/native_functions.yaml    |  3 +-
 c10/core/Device.cpp                           |  3 +-
 c10/core/DeviceType.cpp                       |  3 +
 c10/core/DeviceType.h                         |  6 +-
 c10/core/DispatchKey.h                        | 59 +++----------
 c10/core/TensorOptions.h                      |  4 ++
 .../check_backward_compatibility.py           |  1 +
 test/test_torch.py                            |  8 +--
 tools/codegen/dest/register_dispatch_key.py   |  2 +-
 torch/csrc/autograd/init.cpp                  |  1 +
 torch/library.h                               |  2 +
 torch/overrides.py                            |  1 -
 .../distributed/nn/api/remote_module_test.py  |  3 +-
 16 files changed, 43 insertions(+), 81 deletions(-)
 create mode 100644 aten/src/ATen/detail/MetaGuardImpl.cpp
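As a quick taste of the new API, here is a sketch that mirrors the updated
test_empty_meta below (per the summary, an operator only works on meta
tensors once a Meta kernel has been registered for it, e.g. via structured
kernels):

    import torch

    # Allocation is metadata-only: no storage is ever created, so even
    # absurdly large sizes are cheap.
    x = torch.empty(2 ** 20, 2 ** 20, device='meta')
    y = torch.empty(2 ** 20, device='meta')

    # Dry run: shape and dtype propagate, but nothing is computed.
    z = x + y
    assert z.size() == (2 ** 20, 2 ** 20)
    assert z.device.type == 'meta'

    # There is no data to read back, so materializing a value raises:
    # z[0][0].item()  # RuntimeError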
diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp
index 1a17eb86eaa..79d94540fd5 100644
--- a/aten/src/ATen/TensorIterator.cpp
+++ b/aten/src/ATen/TensorIterator.cpp
@@ -1373,17 +1373,9 @@ void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayR
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
   if (!op.tensor.defined()) {
     if (strides.empty()) {
-      if (is_meta_) {
-        op.tensor = at::empty_meta(sizes, options);
-      } else {
-        op.tensor = at::empty(sizes, options);
-      }
+      op.tensor = at::empty(sizes, options);
     } else {
-      if (is_meta_) {
-        TORCH_INTERNAL_ASSERT(0, "meta strided not yet implemented");
-      } else {
-        op.tensor = at::empty_strided(sizes, strides, options);
-      }
+      op.tensor = at::empty_strided(sizes, strides, options);
     }
     op.current_dtype = op.target_dtype;
   } else if (op.will_resize) {
diff --git a/aten/src/ATen/detail/MetaGuardImpl.cpp b/aten/src/ATen/detail/MetaGuardImpl.cpp
new file mode 100644
index 00000000000..f169f077f1e
--- /dev/null
+++ b/aten/src/ATen/detail/MetaGuardImpl.cpp
@@ -0,0 +1,9 @@
+#include <c10/core/DeviceType.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+
+namespace at {
+namespace detail {
+
+C10_REGISTER_GUARD_IMPL(Meta, c10::impl::NoOpDeviceGuardImpl<DeviceType::Meta>);
+
+}} // namespace at::detail
diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp
index af293d7ebe2..ab330c92894 100644
--- a/aten/src/ATen/native/MetaTensor.cpp
+++ b/aten/src/ATen/native/MetaTensor.cpp
@@ -4,7 +4,6 @@
 namespace at {
 namespace native {
 
-// Will be promoted to a public API later, but not now
 Tensor empty_meta(
   IntArrayRef size,
   c10::optional<ScalarType> dtype,
@@ -16,11 +15,7 @@ Tensor empty_meta(
   // TODO: deduplicate this logic with empty_cpu
   auto tensor = detail::make_tensor<TensorImpl>(
-    // NB: We include the computed dispatch key, not because it will actually
-    // participate in dispatch, but so that tests like is_sparse/is_cuda
-    // give the correct result (a CUDA meta tensor "is cuda"). If we don't
-    // like this, remove the computeDispatchKey line
-    DispatchKeySet{DispatchKey::Meta, computeDispatchKey(dtype, layout, device)},
+    DispatchKeySet{DispatchKey::Meta},
     scalarTypeToTypeMeta(dtype_or_default(dtype)),
     device
   );
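Concretely, because empty_meta now sets only the Meta dispatch key rather
than folding in the computed CPU/CUDA key, a meta tensor no longer claims
membership in any real backend. A minimal sketch of the observable effect
(assuming nothing other than the change above feeds is_cuda):

    import torch

    t = torch.empty(2, 3, device='meta')
    assert t.device.type == 'meta'
    # Under the old scheme, a "CUDA meta tensor" would report is_cuda
    # == True; meta tensors are now device agnostic.
    assert not t.is_cuda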
"meta" : "META"; default: TORCH_CHECK(false, "Unknown device: ", @@ -71,6 +73,7 @@ bool isValidDeviceType(DeviceType d) { case DeviceType::Vulkan: case DeviceType::Metal: case DeviceType::XPU: + case DeviceType::Meta: return true; default: return false; diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 5bdb5eaf903..8ba366abda4 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -26,12 +26,13 @@ enum class DeviceType : int8_t { Vulkan = 10, // Vulkan Metal = 11, // Metal XPU = 12, // XPU - MLC = 13, //ML Compute / Apple + MLC = 13, // ML Compute / Apple + Meta = 14, // Meta (tensors with no data) // NB: If you add more devices: // - Change the implementations of DeviceTypeName and isValidDeviceType // in DeviceType.cpp // - Change the number below - COMPILE_TIME_MAX_DEVICE_TYPES = 14, + COMPILE_TIME_MAX_DEVICE_TYPES = 15, }; constexpr DeviceType kCPU = DeviceType::CPU; @@ -41,6 +42,7 @@ constexpr DeviceType kFPGA = DeviceType::FPGA; constexpr DeviceType kMSNPU = DeviceType::MSNPU; constexpr DeviceType kXLA = DeviceType::XLA; constexpr DeviceType kMLC = DeviceType::MLC; +constexpr DeviceType kMeta = DeviceType::Meta; constexpr DeviceType kVulkan = DeviceType::Vulkan; constexpr DeviceType kMetal = DeviceType::Metal; constexpr DeviceType kXPU = DeviceType::XPU; diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 40d952c67a4..9ad7edc56d4 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -76,6 +76,13 @@ enum class DispatchKey : uint8_t { OpenCL, IDEEP, + // A meta tensor is a tensor without any data associated with it. (They + // have also colloquially been referred to as tensors on the "null" device). + // A meta tensor can be used to dry run operators without actually doing any + // computation, e.g., add on two meta tensors would give you another meta + // tensor with the output shape and dtype, but wouldn't actually add anything. + Meta, + // Here are backends which specify more specialized operators // based on the dtype of the tensor. QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp @@ -123,58 +130,6 @@ enum class DispatchKey : uint8_t { // If you add new backend keys after PrivateUse3, please also update it here. EndOfBackendKeys = PrivateUse3, - // The meta function characterizes how an operation affects the metadata of a - // tensor (shape, dtype) without doing any of the actual computation. A - // meta tensor can be used to dry run operators without actually doing - // any computation, e.g., add on two meta tensors would give you another - // meta tensor with the output shape and dtype, but wouldn't actually - // add anything. A meta implementation typically would look something like: - // - // Tensor meta::add(const Tensor& self, const Tensor& other) { - // TORCH_CHECK(self.size().equals(other.size())); - // return at::empty_like(self, self.size()); - // } - // - // The meta function would get invoked if you ran an operator passing - // in meta tensors. The call stack in such a case would look something like - // this: - // - // at::add(x: Meta, y: Meta) { - // return [dispatch] meta::add(x: Meta, y: Meta) { - // output_shape = ... - // [dispatch] meta::empty(output_shape) { - // return ... meta tensor with output_shape but no data allocated ... - // } - // } - // } - // - // Meta functions have an important secondary function, which is they can - // be used as tensor "allocators". 
-  // be implemented in this way:
-  //
-  //   Tensor cpu::add(const Tensor& self, const Tensor& other) {
-  //     Tensor result = meta::add(self, other);
-  //     // ... do the actual computation into result ...
-  //     return result;
-  //   }
-  //
-  // In this case, the internal at::empty_like invocation would dispatch to the
-  // CPU factory function, not the meta factory function.  The call stack in
-  // this case looks like:
-  //
-  //   at::add(x: CPU, y: CPU) {
-  //     return [dispatch] cpu::add(x: CPU, y: CPU) {
-  //       output = [direct] meta::add(x: CPU, y: CPU) {
-  //         output_shape = ...
-  //         [dispatch] cpu::empty(output_shape)
-  //       }
-  //       ... compute on output ...
-  //       return output;
-  //     }
-  //   }
-  //
-  Meta,
-
   // In some situations, it is not immediately obvious what the correct
   // backend for function is, because the function in question doesn't
   // have any "tensor" arguments.  In this case, a BackendSelect function
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 5b70d1252a5..980c683512b 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -635,6 +635,8 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
         return DispatchKey::Vulkan;
       case DeviceType::Metal:
         return DispatchKey::Metal;
+      case DeviceType::Meta:
+        return DispatchKey::Meta;
       default:
         TORCH_CHECK(false, "Unsupported device type for dense layout: ", device_.type());
     }
@@ -691,6 +693,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
     return DeviceType::XLA;
   } else if (tid == DispatchKey::MLC) {
     return DeviceType::MLC;
+  } else if (tid == DispatchKey::Meta) {
+    return DeviceType::Meta;
   } else if (tid == DispatchKey::SparseCPU) {
     return DeviceType::CPU;
   } else if (tid == DispatchKey::SparseCUDA) {
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 14fdefaae36..147eb9764fd 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -74,6 +74,7 @@ allow_list = [
     ("aten::_foreach_addcdiv", datetime.date(2021, 2, 25)),
     ("aten::mkldnn_linear", datetime.date(2021, 3, 2)),
     ("aten::linalg_multi_dot", datetime.date(2021, 3, 25)),
+    ("aten::empty_meta", datetime.date(2021, 4, 1)),
 ]
 
 def allow_listed(schema, allow_list):
diff --git a/test/test_torch.py b/test/test_torch.py
index 81fab01d4c1..5c43a9380ed 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2556,8 +2556,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertEqual(output3, output2)
 
     def test_empty_meta(self):
-        x = torch.empty_meta(2 ** 20, 2 ** 20)
-        y = torch.empty_meta(2 ** 20)
+        x = torch.empty(2 ** 20, 2 ** 20, device='meta')
+        y = torch.empty(2 ** 20, device='meta')
         z = x + y
         self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
         self.assertRaises(RuntimeError, lambda: z[0][0].item())
@@ -2568,14 +2568,14 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         # integrated testing strategy
         # NB: Can't make the exponent too big, or it will overflow
         # signed 64-bit integer
-        x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
+        x = torch.empty(2 * 10 ** 8, 3, 2 * 10 ** 8, device='meta')
         z = torch.nn.functional.interpolate(x, scale_factor=2)
         self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
         self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
 
         # interpolate doesn't seem to support out=
         # (not sure why passing None here doesn't work? How strange...)
-        z = torch.empty_meta(0)
+        z = torch.empty(0, device='meta')
         torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
         self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
         self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
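The updated test shows the pattern this patch enables: shape inference by
dry-running real operators. A small hypothetical helper (infer_output_shape
is not part of this patch) that packages the idiom:

    import torch

    def infer_output_shape(fn, *input_shapes, dtype=torch.float):
        # Dry-run fn on meta tensors: shapes and dtypes propagate, but no
        # storage is allocated and no arithmetic runs.  Only valid for ops
        # that have Meta kernels registered (e.g. structured kernels).
        args = [torch.empty(s, dtype=dtype, device='meta')
                for s in input_shapes]
        return fn(*args).shape

    print(infer_output_shape(torch.add, (2 ** 20, 2 ** 20), (2 ** 20,)))
    # torch.Size([1048576, 1048576])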
diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py
index 19b30437709..920dcc5f2aa 100644
--- a/tools/codegen/dest/register_dispatch_key.py
+++ b/tools/codegen/dest/register_dispatch_key.py
@@ -247,7 +247,7 @@ if (C10_UNLIKELY(current_device.has_value())) {
         if self.dispatch_key == DispatchKey.Meta:
             return """
 if (strides.empty()) {
-    outputs_[output_idx] = at::empty_meta(sizes, options);
+    outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
 } else {
     TORCH_INTERNAL_ASSERT(0, "not implemented yet");
 }
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index e7e3f34517b..b9e4e95efb5 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -96,6 +96,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("MSNPU", c10::DeviceType::MSNPU)
       .value("XLA", c10::DeviceType::XLA)
       .value("MLC", c10::DeviceType::MLC)
+      .value("Meta", c10::DeviceType::Meta)
       .value("Vulkan", c10::DeviceType::Vulkan)
       .value("Metal", c10::DeviceType::Metal);
diff --git a/torch/library.h b/torch/library.h
index 8aa13eaf9a3..8d07b16aa57 100644
--- a/torch/library.h
+++ b/torch/library.h
@@ -294,6 +294,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
       return c10::DispatchKey::XLA;
     case c10::DeviceType::MLC:
       return c10::DispatchKey::MLC;
+    case c10::DeviceType::Meta:
+      return c10::DispatchKey::Meta;
     case c10::DeviceType::HIP:
       return c10::DispatchKey::HIP;
     case c10::DeviceType::MSNPU:
diff --git a/torch/overrides.py b/torch/overrides.py
index 584830048c2..885ef157b64 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -129,7 +129,6 @@ def get_ignored_functions() -> Set[Callable]:
         torch.cudnn_grid_sampler,
         torch.cudnn_is_acceptable,
         torch.empty,
-        torch.empty_meta,
         torch.empty_strided,
         torch.empty_quantized,
         torch.eye,
diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
index 849382eef94..c98ca983580 100644
--- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py
+++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -257,8 +257,7 @@ class RemoteModuleTest(RpcAgentTestFixture):
         with self.assertRaisesRegex(
             RuntimeError,
-            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
-            " device type at start of device string",
+            r"Expected one of .+ device type at start of device string",
         ):
             list(
                 self._create_remote_module_iter(