Revert "Make Context to be Device-agnostic Step by Step (2/N) (#136526)"

This reverts commit 8aacbee8e0.

Reverted https://github.com/pytorch/pytorch/pull/136526 on behalf of https://github.com/wdvr due to this one has failing internal tests, not related to a landrace with #138398 - reverting this one ([comment](https://github.com/pytorch/pytorch/pull/136526#issuecomment-2430460176))
This commit is contained in:
PyTorch MergeBot 2024-10-22 22:53:56 +00:00
parent 39bfba3f56
commit 10f16cc7da
20 changed files with 83 additions and 73 deletions

View file

@ -43,9 +43,19 @@ class TORCH_API Context {
if (device_type == at::kCPU) {
return at::detail::getDefaultCPUGenerator();
} else if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
} else if (device_type == at::kMPS) {
return at::detail::getMPSHooks().getDefaultMPSGenerator();
} else if (device_type == at::kXPU) {
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
} else if (device_type == at::kIPU) {
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
device.index());
} else {
return getAcceleratorHooksInterface(device_type)
.getDefaultGenerator(device.index());
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
}
}

View file

@ -102,7 +102,7 @@ void CUDAHooks::init() const {
#endif
}
const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const {
const Generator& CUDAHooks::getDefaultCUDAGenerator(DeviceIndex device_index) const {
return at::cuda::detail::getDefaultCUDAGenerator(device_index);
}

View file

@ -21,8 +21,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
void init() const override;
Device getDeviceFromPtr(void* data) const override;
bool isPinnedPtr(const void* data) const override;
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
bool hasCUDA() const override;
bool hasMAGMA() const override;
bool hasCuDNN() const override;

View file

@ -1,13 +1,9 @@
#pragma once
#include <ATen/core/Generator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <c10/core/Allocator.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
// AcceleratorHooksInterface is a shared interface provided by all
@ -62,18 +58,7 @@ struct TORCH_API AcceleratorHooksInterface {
virtual Device getDeviceFromPtr(void* data) const {
TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()");
}
virtual const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()");
}
virtual Generator getNewGenerator(
C10_UNUSED DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()");
}
};
} // namespace at
C10_DIAGNOSTIC_POP()

View file

@ -6,13 +6,16 @@
#include <ATen/detail/AcceleratorHooksInterface.h>
// NB: Class must live in `at` due to limitations of Registry.h.
// Forward-declares at::Generator and at::cuda::NVRTC
namespace at {
// Forward-declares at::cuda::NVRTC
struct Generator;
namespace cuda {
struct NVRTC;
} // namespace cuda
} // namespace at
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {
#ifdef _MSC_VER
constexpr const char* CUDA_HELP =
@ -66,8 +69,8 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
}
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
virtual const Generator& getDefaultCUDAGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const {
TORCH_CHECK(
false,
"Cannot get default CUDA generator without ATen_cuda library. ",

View file

@ -1,13 +1,19 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <memory>
namespace at {
class Context;
}
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {
@ -24,9 +30,8 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
}
const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
}
virtual bool hasHIP() const {
@ -45,6 +50,10 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Pinned memory requires HIP.");
}
virtual void registerHIPTypes(Context*) const {
AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
}
virtual int getNumGPUs() const {
return 0;
}

View file

@ -1,5 +1,6 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
@ -8,7 +9,7 @@
namespace at {
struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface {
struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
~IPUHooksInterface() override = default;
void init() const override {
@ -20,14 +21,16 @@ struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface {
return false;
}
const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
virtual const Generator& getDefaultIPUGenerator(
DeviceIndex device_index [[maybe_unused]] = -1) const {
AT_ERROR(
"Cannot get the default IPU generator: the IPU backend is not "
"available.");
}
Generator getNewGenerator(
DeviceIndex device_index [[maybe_unused]] = -1) const override {
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
virtual Generator newIPUGenerator(DeviceIndex device_index [[maybe_unused]] = -1) const {
AT_ERROR(
"Cannot create a new IPU generator: the IPU backend is not available.");
}
};

View file

@ -2,9 +2,9 @@
#pragma once
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
@ -31,8 +31,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const override {
virtual const Generator& getDefaultMPSGenerator() const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual Allocator* getMPSDeviceAllocator() const {

View file

@ -1,20 +1,18 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Storage.h>
#include <c10/util/Exception.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
~PrivateUse1HooksInterface() override = default;
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
virtual const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
@ -26,17 +24,17 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
}
bool isPinnedPtr(const void* data) const override {
virtual bool isPinnedPtr(const void* data) const override {
return false;
}
Allocator* getPinnedMemoryAllocator() const override {
virtual Allocator* getPinnedMemoryAllocator() const override {
TORCH_CHECK(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");

View file

@ -4,6 +4,7 @@
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
@ -31,15 +32,15 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
}
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false, "Cannot get default XPU generator without ATen_xpu library.");
virtual Generator getXPUGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
virtual const Generator& getDefaultXPUGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const {
TORCH_CHECK(
false, "Cannot get default XPU generator without ATen_xpu library.");
}
virtual DeviceIndex getNumGPUs() const {

View file

@ -19,8 +19,7 @@ struct MPSHooks : public at::MPSHooksInterface {
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
// MPSGeneratorImpl interface
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
const Generator& getDefaultMPSGenerator() const override;
// MPSStream interface
void deviceSynchronize() const override;

View file

@ -59,7 +59,7 @@ Allocator* MPSHooks::getMPSDeviceAllocator() const {
return at::mps::GetMPSAllocator();
}
const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex device_index) const {
const Generator& MPSHooks::getDefaultMPSGenerator() const {
return at::mps::detail::getDefaultMPSGenerator();
}

View file

@ -2363,7 +2363,8 @@ DropoutState& get_dropout_state(
std::unique_lock<std::mutex> lock{state_cache_mut};
auto& state = dropout_state_cache.at(device);
if (train && dropout_p > 0) {
const auto& gen = at::detail::getCUDAHooks().getDefaultGenerator(device);
const auto& gen =
at::detail::getCUDAHooks().getDefaultCUDAGenerator(device);
auto gen_impl = gen.get<at::CUDAGeneratorImpl>();
bool reset_rnn_state = gen_impl->reset_rnn_state();
if (!state.buffer.defined() || reset_rnn_state) {

View file

@ -34,12 +34,13 @@ int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const {
#endif
}
const Generator& XPUHooks::getDefaultGenerator(DeviceIndex device_index) const {
return at::xpu::detail::getDefaultXPUGenerator(device_index);
Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const {
return make_generator<at::XPUGeneratorImpl>(device_index);
}
Generator XPUHooks::getNewGenerator(DeviceIndex device_index) const {
return make_generator<at::XPUGeneratorImpl>(device_index);
const Generator& XPUHooks::getDefaultXPUGenerator(
DeviceIndex device_index) const {
return at::xpu::detail::getDefaultXPUGenerator(device_index);
}
Device XPUHooks::getDeviceFromPtr(void* data) const {

View file

@ -11,9 +11,9 @@ struct XPUHooks : public at::XPUHooksInterface {
bool hasXPU() const override;
std::string showConfig() const override;
int32_t getGlobalIdxFromDevice(const at::Device& device) const override;
const Generator& getDefaultGenerator(
Generator getXPUGenerator(DeviceIndex device_index = -1) const override;
const Generator& getDefaultXPUGenerator(
DeviceIndex device_index = -1) const override;
Generator getNewGenerator(DeviceIndex device_index = -1) const override;
Device getDeviceFromPtr(void* data) const override;
c10::DeviceIndex getNumGPUs() const override;
DeviceIndex current_device() const override;

View file

@ -73,9 +73,9 @@ static PyObject* THPGenerator_pynew(
}
#endif
else if (device.type() == at::kXPU) {
self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index());
self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
} else if (device.type() == at::kIPU) {
self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index());
self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index());
} else if (device.type() == at::kPrivateUse1) {
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
} else {

View file

@ -28,7 +28,7 @@ bool cudnn_is_available() {
void manual_seed(uint64_t seed) {
if (is_available()) {
auto index = at::detail::getCUDAHooks().getCurrentDevice();
auto gen = at::detail::getCUDAHooks().getDefaultGenerator(index);
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
@ -41,7 +41,8 @@ void manual_seed(uint64_t seed) {
void manual_seed_all(uint64_t seed) {
auto num_gpu = device_count();
for (const auto i : c10::irange(num_gpu)) {
auto gen = at::detail::getCUDAHooks().getDefaultGenerator(i);
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(
static_cast<c10::DeviceIndex>(i));
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());

View file

@ -10,7 +10,7 @@ bool is_available() {
/// Sets the seed for the MPS's default generator.
void manual_seed(uint64_t seed) {
if (is_available()) {
auto gen = at::detail::getMPSHooks().getDefaultGenerator();
auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator();
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());

View file

@ -14,7 +14,7 @@ bool is_available() {
void manual_seed(uint64_t seed) {
if (is_available()) {
auto index = at::detail::getXPUHooks().getCurrentDevice();
auto gen = at::detail::getXPUHooks().getDefaultGenerator(index);
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(index);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
@ -27,7 +27,8 @@ void manual_seed(uint64_t seed) {
void manual_seed_all(uint64_t seed) {
auto num_gpu = device_count();
for (const auto i : c10::irange(num_gpu)) {
auto gen = at::detail::getXPUHooks().getDefaultGenerator(i);
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(
static_cast<c10::DeviceIndex>(i));
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());

View file

@ -44,7 +44,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
HANDLE_TH_ERRORS
track_bad_mps_fork();
return THPGenerator_initDefaultGenerator(
at::detail::getMPSHooks().getDefaultGenerator());
at::detail::getMPSHooks().getDefaultMPSGenerator());
END_HANDLE_TH_ERRORS
}