mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53143 Meta is now an honest to goodness device type, like cpu, so you can use device='meta' to trigger allocation of meta tensors. This way better than empty_meta since we now have working API for most factory functions (they don't necessarily work yet, though, because need to register Meta versions of those functions.) Some subtleties: - I decided to drop the concept of CPU versus CUDA meta tensors; meta tensors are device agnostic. It's hard to say exactly what the correct level of abstraction here is, but in this particular case implementation considerations trump semantic considerations: it is way easier to have just a meta device, than to have a meta device AND a cpu device AND a cuda device. This may limit the applicability of meta tensors for tracing models that do explicit cpu()/cuda() conversions (unless, perhaps, we make those operations no-ops on meta tensors). - I noticed that the DeviceType uppercase strings are kind of weird. Are they really supposed to be all caps? That's weird. - I moved the Meta dispatch key to live with the rest of the "device" dispatch keys. - I intentionally did NOT add a Backend for Meta. For now, I'm going to hope meta tensors never exercise any of the Backend conversion code; even if it does, better to fix the code to just stop converting to and from Backend. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: samestep Differential Revision: D26763552 Pulled By: ezyang fbshipit-source-id: 14633b6ca738e60b921db66a763155d01795480d
103 lines
3.2 KiB
C++
#include <c10/core/Device.h>
|
|
#include <c10/macros/Macros.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <exception>
|
|
#include <ostream>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <vector>
|
|
#include <regex>
|
|
|
|
// Check if compiler has working std::regex implementation
|
|
//
|
|
// Test below is adapted from https://stackoverflow.com/a/41186162
|
|
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
|
|
// Compiler has working regex. MSVC has erroneous __cplusplus.
|
|
#elif __cplusplus >= 201103L && \
|
|
(!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
|
|
(defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
|
|
defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
|
|
(defined(_GLIBCXX_RELEASE) && \
|
|
_GLIBCXX_RELEASE > 4)))
|
|
// Compiler has working regex.
|
|
#else
|
|
static_assert(false, "Compiler does not have proper regex support.");
|
|
#endif
|
|
|
|
namespace c10 {
|
|
namespace {
|
|
// Maps a lowercase device-string prefix (e.g. "cuda") to its DeviceType.
// Raises (via TORCH_CHECK) if the string names no known device type.
DeviceType parse_type(const std::string& device_string) {
  // Static lookup table: one entry per recognized device-string spelling.
  static const std::array<
      std::pair<std::string, DeviceType>,
      static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
      types = {{
          {"cpu", DeviceType::CPU},
          {"cuda", DeviceType::CUDA},
          {"xpu", DeviceType::XPU},
          {"mkldnn", DeviceType::MKLDNN},
          {"opengl", DeviceType::OPENGL},
          {"opencl", DeviceType::OPENCL},
          {"ideep", DeviceType::IDEEP},
          {"hip", DeviceType::HIP},
          {"fpga", DeviceType::FPGA},
          {"msnpu", DeviceType::MSNPU},
          {"xla", DeviceType::XLA},
          {"vulkan", DeviceType::Vulkan},
          {"mlc", DeviceType::MLC},
          {"meta", DeviceType::Meta},
      }};
  // Capture by const reference: capturing by value would copy the string
  // on every lookup for no benefit.
  auto device = std::find_if(
      types.begin(),
      types.end(),
      [&device_string](const std::pair<std::string, DeviceType>& p) {
        return p.first == device_string;
      });
  if (device != types.end()) {
    return device->second;
  }
  // Keep this list in sync with the table above (it previously omitted fpga).
  TORCH_CHECK(false,
      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, fpga, msnpu, mlc, xla, vulkan, meta device type at start of device string: ",
      device_string);
}
|
|
} // namespace
|
|
|
|
// Parses a device string of the form "<type>" or "<type>:<index>"
// (e.g. "cuda", "cuda:1"). Throws via TORCH_CHECK on malformed input.
Device::Device(const std::string& device_string) : Device(Type::CPU) {
  TORCH_CHECK(!device_string.empty(), "Device string must not be empty");

  // We assume gcc 5+, so we can use proper regex.
  // Group 1: device-type name; group 2 (optional): non-negative index
  // with no leading zeros (except "0" itself).
  static const std::regex pattern("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
  std::smatch pieces;
  const bool matched = std::regex_match(device_string, pieces, pattern);
  TORCH_CHECK(
      matched,
      "Invalid device string: '", device_string, "'");
  type_ = parse_type(pieces[1].str());
  if (pieces[2].matched) {
    const std::string index_str = pieces[2].str();
    try {
      index_ = c10::stoi(index_str);
    } catch (const std::exception&) {
      // c10::stoi can throw on out-of-range values; surface a clearer error.
      TORCH_CHECK(false,
          "Could not parse device index '", index_str,
          "' in device string '", device_string, "'");
    }
  }
  validate();
}
|
|
|
|
std::string Device::str() const {
|
|
std::string str = DeviceTypeName(type(), /* lower case */ true);
|
|
if (has_index()) {
|
|
str.push_back(':');
|
|
str.append(to_string(index()));
|
|
}
|
|
return str;
|
|
}
|
|
|
|
// Streams the canonical string form of the device (see Device::str()).
std::ostream& operator<<(std::ostream& stream, const Device& device) {
  return stream << device.str();
}
|
|
|
|
} // namespace c10
|