Use Devices instead of DeviceIndexes in TensorPipe agent (#57294)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57294

With the advent of CPUs in the device maps, and to be more generic (e.g., to support AMD GPUs), and to avoid conversions when passing to Future and RRef and such, it's easier to use Devices instead of DeviceIndices. This started by just migrating the TensorPipe agent but the RPC layer is quite intertwined so I had to migrate a lot of stuff.
ghstack-source-id: 127916562

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28092733

fbshipit-source-id: 024dcb3648c5898ab13e770413c43958f04f1a8a
This commit is contained in:
Luca Wehrstedt 2021-05-01 16:09:43 -07:00 committed by Facebook GitHub Bot
parent 0c3e79b5b9
commit 0422e67336
26 changed files with 176 additions and 233 deletions

View file

@ -46,7 +46,7 @@ class TestE2EBase : public ::testing::Test {
return c10::StrongTypePtr(
nullptr,
c10::DictType::create(
c10::IntType::get(), c10::IntType::get()));
c10::StringType::get(), c10::StringType::get()));
}
return c10::StrongTypePtr(
nullptr, c10::TensorType::create(at::Tensor()));

View file

@ -11,6 +11,8 @@
TEST(TensorpipeSerialize, Base) {
// Sender serializes
auto lazyStreamCtx =
std::make_shared<torch::distributed::rpc::LazyStreamContext>(c10::kCPU);
at::Tensor t1 = torch::ones({1024}, at::ScalarType::Int);
at::Tensor t2 = torch::ones({1024}, at::ScalarType::Float);
std::vector<at::Tensor> tensors{t1, t2};
@ -26,7 +28,7 @@ TEST(TensorpipeSerialize, Base) {
torch::distributed::rpc::TensorpipeWriteBuffers sendingTpBuffers;
std::tie(sendingTpMessage, sendingTpBuffers) =
torch::distributed::rpc::tensorpipeSerialize(
std::move(sendingRpcMessage));
std::move(sendingRpcMessage), {}, lazyStreamCtx);
// Mimic receiving message descriptor: recvingTpDescriptor is a copy of
// sendingTpMessage except for the data pointers which are left null.
@ -59,7 +61,8 @@ TEST(TensorpipeSerialize, Base) {
tensorpipe::Allocation recvingTpAllocation;
torch::distributed::rpc::TensorpipeReadBuffers recvingTpBuffers;
std::tie(recvingTpAllocation, recvingTpBuffers) =
torch::distributed::rpc::tensorpipeAllocate(recvingTpDescriptor);
torch::distributed::rpc::tensorpipeAllocate(
recvingTpDescriptor, lazyStreamCtx);
// Mimic tensorpipe data transfer
EXPECT_EQ(
@ -101,6 +104,8 @@ TEST(TensorpipeSerialize, Base) {
TEST(TensorpipeSerialize, RecopySparseTensors) {
// Take a 1K row of a 1M tensor, and make sure we don't send across 1M rows.
auto lazyStreamCtx =
std::make_shared<torch::distributed::rpc::LazyStreamContext>(c10::kCPU);
constexpr size_t k1K = 1024;
at::Tensor main = torch::randn({k1K, k1K});
at::Tensor tiny = main.select(0, 2); // Select a row in the middle
@ -118,7 +123,7 @@ TEST(TensorpipeSerialize, RecopySparseTensors) {
torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers;
std::tie(sendingTpMessage, tpBuffers) =
torch::distributed::rpc::tensorpipeSerialize(
std::move(sendingRpcMessage));
std::move(sendingRpcMessage), {}, lazyStreamCtx);
EXPECT_EQ(tpBuffers.tensors.size(), 2);
EXPECT_EQ(sendingTpMessage.tensors.size(), 2);
@ -135,6 +140,8 @@ TEST(TensorpipeSerialize, RecopySparseTensors) {
}
TEST(TensorpipeSerialize, NoDeleterTensors) {
auto lazyStreamCtx =
std::make_shared<torch::distributed::rpc::LazyStreamContext>(c10::kCPU);
std::vector<float> blob1{.8, .2};
std::vector<float> blob2{.7, .5, .9};
at::Tensor t1 = torch::from_blob((float*)(blob1.data()), blob1.size());
@ -150,7 +157,7 @@ TEST(TensorpipeSerialize, NoDeleterTensors) {
torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers;
std::tie(sendingTpMessage, tpBuffers) =
torch::distributed::rpc::tensorpipeSerialize(
std::move(sendingRpcMessage));
std::move(sendingRpcMessage), {}, lazyStreamCtx);
EXPECT_EQ(tpBuffers.copiedTensors.size(), 2);
EXPECT_EQ(sendingTpMessage.tensors.size(), 2);

View file

@ -41,7 +41,7 @@ class RpcAgent:
@overload
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[int, int]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
def get_debug_info(self) -> Dict[str, str]: ...
def get_metrics(self) -> Dict[str, str]: ...
@ -91,14 +91,14 @@ class ProcessGroupAgent(RpcAgent):
@overload
def get_worker_info(self, id: int) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[int, int]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
def join(self): ...
def shutdown(self): ...
def sync(self): ...
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
num_worker_threads: int
device_maps: Dict[str, Dict[int, int]]
device_maps: Dict[str, Dict[torch.device, torch.device]]
def __init__(
self,
num_worker_threads: int,
@ -106,9 +106,9 @@ class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
_channels: Optional[List],
rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = _DEFAULT_INIT_METHOD,
device_maps: Dict[str, Dict[int, int]] = dict(),
devices: List[int] = list()): ...
def _set_device_map(self, to: str, device_map: Dict[int, int]): ...
device_maps: Dict[str, Dict[torch.device, torch.device]] = dict(),
devices: List[torch.device] = list()): ...
def _set_device_map(self, to: str, device_map: Dict[torch.device, torch.device]): ...
class TensorPipeAgent(RpcAgent):
def __init__(
@ -129,8 +129,8 @@ class TensorPipeAgent(RpcAgent):
@overload
def get_worker_info(self, id: int) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[int, int]: ...
def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[torch.device, torch.device]]): ...
def _is_current_rpc_agent_set() -> bool: ...
def _get_current_rpc_agent()-> RpcAgent: ...

View file

@ -14,7 +14,7 @@ RecvRpcBackward::RecvRpcBackward(
const AutogradMetadata& autogradMetadata,
ContextPtr autogradContext,
rpc::worker_id_t fromWorkerId,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
std::unordered_map<c10::Device, c10::Device> deviceMap)
: autogradMetadata_(autogradMetadata),
// NOLINTNEXTLINE(performance-move-const-arg)
autogradContext_(std::move(autogradContext)),

View file

@ -23,7 +23,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
const AutogradMetadata& autogradMetadata,
std::shared_ptr<DistAutogradContext> autogradContext,
rpc::worker_id_t fromWorkerId,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap);
std::unordered_map<c10::Device, c10::Device> deviceMap);
torch::autograd::variable_list apply(
torch::autograd::variable_list&& grads) override;
@ -41,7 +41,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
rpc::worker_id_t fromWorkerId_;
// Device mapping for tensors sent over RPC.
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
const std::unordered_map<c10::Device, c10::Device> deviceMap_;
};
} // namespace autograd

View file

@ -19,7 +19,7 @@ RpcWithAutograd::RpcWithAutograd(
MessageType messageType,
const AutogradMetadata& autogradMetadata,
rpc::Message&& wrappedMessage,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
std::unordered_map<c10::Device, c10::Device> deviceMap)
: fromWorkerId_(fromWorkerId),
messageType_(messageType),
autogradMetadata_(autogradMetadata),
@ -39,7 +39,7 @@ RpcWithAutograd::RpcWithAutograd(
std::unique_ptr<RpcCommandBase> wrappedRpc,
MessageType wrappedMessageType,
std::vector<torch::Tensor> tensors,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap)
std::unordered_map<c10::Device, c10::Device> deviceMap)
: fromWorkerId_(fromWorkerId),
messageType_(messageType),
autogradMetadata_(autogradMetadata),
@ -61,9 +61,9 @@ Message RpcWithAutograd::toMessageImpl() && {
TORCH_INTERNAL_ASSERT(!payload.empty());
// Convert deviceMap to c10::Dict for serialization.
c10::Dict<int64_t, int64_t> deviceMap;
c10::Dict<std::string, std::string> deviceMap;
for (const auto& mapEntry : deviceMap_) {
deviceMap.insert(mapEntry.first, mapEntry.second);
deviceMap.insert(mapEntry.first.str(), mapEntry.second.str());
}
std::vector<at::IValue> ivalues{wrappedMessageType,
@ -109,10 +109,10 @@ std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
AutogradMetadata autogradMetadata(
tupleElements[1].toInt(), tupleElements[2].toInt());
worker_id_t workerId = tupleElements[3].toInt();
auto c10DeviceMap = tupleElements[4].to<c10::Dict<int64_t, int64_t>>();
auto c10DeviceMap = tupleElements[4].to<c10::Dict<std::string, std::string>>();
// Convert to regular map.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap;
std::unordered_map<c10::Device, c10::Device> deviceMap;
for (const auto& mapEntry : c10DeviceMap) {
deviceMap.insert({mapEntry.key(), mapEntry.value()});
}
@ -169,7 +169,7 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
return fromWorkerId_;
}
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& RpcWithAutograd::
const std::unordered_map<c10::Device, c10::Device>& RpcWithAutograd::
deviceMap() {
return deviceMap_;
}

View file

@ -19,7 +19,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
rpc::MessageType messageType,
const AutogradMetadata& autogradMetadata,
rpc::Message&& wrappedMessage,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});
std::unordered_map<c10::Device, c10::Device> deviceMap = {});
// Used when receiving an RPC over the wire.
RpcWithAutograd(
@ -29,7 +29,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
rpc::MessageType wrappedMessageType,
std::vector<torch::Tensor> tensors,
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap = {});
std::unordered_map<c10::Device, c10::Device> deviceMap = {});
rpc::Message toMessageImpl() && override;
@ -55,7 +55,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
rpc::worker_id_t fromWorkerId() const;
// Retrieve the device map.
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap();
const std::unordered_map<c10::Device, c10::Device>& deviceMap();
private:
// WorkerId from which this RPC originated. This is necessary for knowing
@ -90,7 +90,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
std::vector<torch::Tensor> tensors_;
// Device mapping for tensors that are sent across an RPC to another node.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> deviceMap_;
std::unordered_map<c10::Device, c10::Device> deviceMap_;
};
} // namespace autograd

View file

@ -53,7 +53,7 @@ ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
// Initialize autograd context if necessary.
auto& autogradContainer = DistAutogradContainer::getInstance();
auto autogradContext =
@ -105,7 +105,7 @@ Message getMessageWithAutograd(
torch::distributed::rpc::Message&& wrappedRpcMsg,
MessageType msgType,
bool forceGradRecording,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
auto& autogradContainer = DistAutogradContainer::getInstance();
// If there is no valid context and no tensor requires grads, send original

View file

@ -31,7 +31,7 @@ TORCH_API ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap);
const std::unordered_map<c10::Device, c10::Device>& deviceMap);
// This method is a wrapper utility used internally to wrap autograd info
// and attach autograd function for each type of rpc call if it has valid
@ -44,7 +44,7 @@ TORCH_API rpc::Message getMessageWithAutograd(
rpc::Message&& wrappedRpcMsg,
rpc::MessageType msgType,
bool forceGradRecording = false,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
const std::unordered_map<c10::Device, c10::Device>& deviceMap =
{});
// Send message after autograd checking

View file

@ -635,16 +635,15 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
optional<std::vector<std::string>>,
float,
std::string,
std::unordered_map<std::string, tensorpipe::DeviceMap>,
std::vector<c10::DeviceIndex>>(),
std::unordered_map<std::string, DeviceMap>,
std::vector<c10::Device>>(),
py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
py::arg("_transports") = optional<std::vector<std::string>>(),
py::arg("_channels") = optional<std::vector<std::string>>(),
py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
py::arg("init_method") = kDefaultInitMethod,
py::arg("device_maps") =
std::unordered_map<std::string, tensorpipe::DeviceMap>(),
py::arg("devices") = std::vector<c10::DeviceIndex>())
py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
py::arg("devices") = std::vector<c10::Device>())
.def_readwrite(
"num_worker_threads",
&TensorPipeRpcBackendOptions::numWorkerThreads,
@ -722,8 +721,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
py::call_guard<py::gil_scoped_release>())
.def(
"_get_device_map",
(tensorpipe::DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst)
const) &
(DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
TensorPipeAgent::getDeviceMap,
py::call_guard<py::gil_scoped_release>())
.def(

View file

@ -270,7 +270,7 @@ std::shared_ptr<JitFuture> ProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
// Throw if we previously encountered an exception in ::listenLoop.
{
std::unique_lock<std::mutex> guard(listenLoopExceptionMutex_);

View file

@ -94,8 +94,8 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) override;
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
override;
// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);

View file

@ -361,7 +361,7 @@ void RequestCallbackNoPython::processForwardAutogradReq(
// Need to reverse the device map for the backward pass of distributed
// autograd.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reverseDeviceMap;
std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
}

View file

@ -287,7 +287,7 @@ bool RpcAgent::isGILProfilingEnabled() {
return profilingEnabled_.load();
}
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> RpcAgent::getDeviceMap(
std::unordered_map<c10::Device, c10::Device> RpcAgent::getDeviceMap(
const WorkerInfo& /* unused */) const {
// Default implementation has no device map.
return {};

View file

@ -164,8 +164,7 @@ class TORCH_API RpcAgent {
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) = 0;
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {}) = 0;
// Retries sending the message up to maxRetries times until an ACK is
// received. The duration between consecutive sends is increased over
@ -265,7 +264,7 @@ class TORCH_API RpcAgent {
std::shared_ptr<TypeResolver> getTypeResolver();
// Retrieves the device map for the provided destination worker.
virtual std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> getDeviceMap(
virtual std::unordered_map<c10::Device, c10::Device> getDeviceMap(
const WorkerInfo& dst) const;
protected:

View file

@ -314,7 +314,7 @@ void OwnerRRef::recordAllStreams(
void OwnerRRef::blockAllStreams(std::shared_ptr<LazyStreamContext>& ctx) {
if (ctx) {
for (c10::Event& event : events_) {
event.block(ctx->getStream(event.device_index()));
event.block(ctx->getStream(event.device()));
}
}
}

View file

@ -41,9 +41,9 @@ const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
std::vector<c10::DeviceIndex> getDevicesForTensors(
std::vector<c10::Device> getDevicesForTensors(
const std::vector<torch::Tensor>& tensors,
const tensorpipe::DeviceMap& deviceMap,
const DeviceMap& deviceMap,
const std::string& remoteName) {
// If the deviceMap is overridden, use that instead.
const auto errStr = c10::str(
@ -53,45 +53,46 @@ std::vector<c10::DeviceIndex> getDevicesForTensors(
"configure device mapping. ",
"Request device mapping is not available for destination ",
remoteName);
std::vector<c10::DeviceIndex> deviceIndices;
deviceIndices.reserve(tensors.size());
std::vector<c10::Device> devices;
devices.reserve(tensors.size());
bool hasCudaTensor = false;
for (const auto& t : tensors) {
if (t.device().is_cpu()) {
deviceIndices.push_back(-1);
devices.emplace_back(c10::kCPU);
} else {
const auto deviceIter = deviceMap.find(t.device().index());
const auto deviceIter = deviceMap.find(t.device());
TORCH_CHECK(
deviceIter != deviceMap.end(),
errStr,
" for device ",
t.device(),
" but received a tensor on that device.");
deviceIndices.push_back(deviceIter->second);
devices.push_back(deviceIter->second);
hasCudaTensor = true;
}
}
if (!hasCudaTensor) {
deviceIndices.clear();
devices.clear();
}
return deviceIndices;
return devices;
}
// A helper function that first creates a LazyStreamContext and then grabs a
// CUDA stream for each device in the given device list.
std::shared_ptr<LazyStreamContext> createCalleeStreamContext(
std::vector<c10::DeviceIndex> devices) {
auto ctx = createLazyStreamContext();
for (const auto& device : devices) {
const std::vector<c10::Device>& devices) {
auto ctx = std::make_shared<LazyStreamContext>(
devices.empty() ? c10::kCPU : devices[0].type());
for (const c10::Device& device : devices) {
ctx->getStream(device);
}
return ctx;
}
// Retrieve local devices (i.e., device keys) from the given map.
std::unordered_set<c10::DeviceIndex> getLocalDevices(
const std::unordered_map<std::string, tensorpipe::DeviceMap>& deviceMap) {
std::unordered_set<c10::DeviceIndex> deviceSet;
std::unordered_set<c10::Device> getLocalDevices(
const std::unordered_map<std::string, DeviceMap>& deviceMap) {
std::unordered_set<c10::Device> deviceSet;
for (const auto& entry : deviceMap) {
for (const auto& device : entry.second) {
deviceSet.insert(device.first);
@ -103,9 +104,9 @@ std::unordered_set<c10::DeviceIndex> getLocalDevices(
// 1) checks there is no duplication in the devices field
// 2) checks all local devices in the deviceSet are included in deviceOpt
void checkValidDevicesOption(
const std::unordered_set<c10::DeviceIndex>& deviceSet,
const std::vector<c10::DeviceIndex>& deviceOpt) {
std::unordered_set<c10::DeviceIndex> optsDeviceSet(
const std::unordered_set<c10::Device>& deviceSet,
const std::vector<c10::Device>& deviceOpt) {
std::unordered_set<c10::Device> optsDeviceSet(
deviceOpt.begin(), deviceOpt.end());
// no duplications are allowed in opts_.devices
@ -114,18 +115,17 @@ void checkValidDevicesOption(
"Detected duplication in TensorPipeRpcBackendOptions devices field.");
// opts_.devices must be a superset of local devices in reverseDeviceMaps_
std::vector<c10::DeviceIndex> cut;
std::set_difference(
deviceSet.begin(),
deviceSet.end(),
optsDeviceSet.begin(),
optsDeviceSet.end(),
std::back_inserter(cut));
std::unordered_set<c10::Device> cut;
for (const c10::Device& device : deviceSet) {
if (optsDeviceSet.find(device) == optsDeviceSet.end()) {
cut.insert(device);
}
}
if (!cut.empty()) {
std::ostringstream oss;
std::copy(
cut.begin(), cut.end(), std::ostream_iterator<int32_t>(oss, ", "));
cut.begin(), cut.end(), std::ostream_iterator<c10::Device>(oss, ", "));
TORCH_CHECK(
false,
"The devices field in TensorPipeRpcBackendOptions must either be "
@ -511,8 +511,7 @@ TensorPipeAgent::~TensorPipeAgent() {
}
void TensorPipeAgent::setReverseDeviceMaps(
const std::unordered_map<std::string, tensorpipe::DeviceMap>&
reverseDeviceMaps) {
const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps) {
reverseDeviceMaps_ = reverseDeviceMaps;
// If devices wasn't specified in the options, update devices_ with local
@ -708,11 +707,9 @@ void TensorPipeAgent::pipeRead(
void TensorPipeAgent::pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
Message&& rpcMessage,
std::vector<c10::DeviceIndex>&& devices,
std::vector<c10::Device>&& devices,
std::shared_ptr<LazyStreamContext> ctx,
std::function<void(const tensorpipe::Error&)> fn,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>&
deviceMap) noexcept {
std::function<void(const tensorpipe::Error&)> fn) noexcept {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers tpBuffers;
@ -748,7 +745,7 @@ void TensorPipeAgent::sendCompletedResponseMessage(
std::move(*futureResponseMessage->value().toCustomClass<Message>());
responseMessage.setId(messageId);
std::vector<c10::DeviceIndex> devices;
std::vector<c10::Device> devices;
try {
devices = getDevicesForRemote(pipe->getRemoteName(), responseMessage);
} catch (const std::exception& e) {
@ -760,19 +757,17 @@ void TensorPipeAgent::sendCompletedResponseMessage(
// FIXME: skipping this check when ctxDevices is empty to allow
// RRef.to_here().
for (const auto& tensor : responseMessage.tensors()) {
const auto device = tensor.device().index();
if (device != -1 && ctxDevices.find(device) == ctxDevices.end()) {
const auto device = tensor.device();
if (!device.is_cpu() && ctxDevices.find(device) == ctxDevices.end()) {
std::ostringstream oss;
std::copy(
ctxDevices.begin(),
ctxDevices.end(),
// interpreting c10::DeviceIndex as int32_t to avoid printing
// it as a char.
std::ostream_iterator<int32_t>(oss, ", "));
std::ostream_iterator<c10::Device>(oss, ", "));
responseMessage = createExceptionResponse(
c10::str(
"RPC detected that a user-function output tensor on device ",
int32_t(device),
device,
". This device is not one of the input tensor devices: ",
oss.str(),
"which is not yet supported. Please file a feature request "
@ -913,7 +908,7 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
const WorkerInfo& toWorkerInfo,
Message&& requestMessage,
const float rpcTimeoutSeconds,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
const DeviceMap& deviceMap) {
TORCH_CHECK(
requestMessage.isRequest(),
"TensorPipeAgent::send(..) is only for sending requests.");
@ -960,7 +955,7 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
// Get devices for tensors in the request message. This can throw if device
// maps are not configured properly for this request.
std::vector<c10::DeviceIndex> devices;
std::vector<c10::Device> devices;
if (deviceMap.empty()) {
devices =
getDevicesForRemote(clientPipe.pipe_->getRemoteName(), requestMessage);
@ -1006,7 +1001,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();
auto ctx = createLazyStreamContext();
auto ctx = std::make_shared<LazyStreamContext>(
devices_.empty() ? c10::kCPU : devices_[0].type());
ctx->waitForCurrentStreams(requestMessage.tensors());
pipeWrite(
clientPipe.pipe_,
@ -1093,8 +1089,7 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
std::move(ctx));
}
});
},
deviceMap);
});
return futureResponseMessage->jitFuture;
}
@ -1419,7 +1414,7 @@ void TensorPipeAgent::markFutureWithError(
}
}
std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForRemote(
std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
const std::string& remoteName,
const Message& message) const {
const auto& deviceMaps =
@ -1449,8 +1444,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForRemote(
}
}
tensorpipe::DeviceMap TensorPipeAgent::getDeviceMap(
const WorkerInfo& dst) const {
DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
auto it = opts_.deviceMaps.find(dst.name_);
if (it == opts_.deviceMaps.end()) {
return {};

View file

@ -37,14 +37,14 @@ namespace channel {
class Context;
} // namespace channel
using DeviceMap = std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>;
} // namespace tensorpipe
namespace torch {
namespace distributed {
namespace rpc {
using DeviceMap = std::unordered_map<c10::Device, c10::Device>;
struct LazyStreamContext;
using steady_clock_time_point =
@ -86,8 +86,8 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions {
optional<std::vector<std::string>> channels,
float rpc_timeout,
std::string init_method,
std::unordered_map<std::string, tensorpipe::DeviceMap> device_maps = {},
std::vector<c10::DeviceIndex> devices = {})
std::unordered_map<std::string, DeviceMap> device_maps = {},
std::vector<c10::Device> devices = {})
: RpcBackendOptions(rpc_timeout, init_method),
numWorkerThreads(numWorkerThreads),
transports(std::move(transports)),
@ -119,15 +119,20 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions {
}
}
void setDeviceMap(
const std::string& workerName,
const tensorpipe::DeviceMap& deviceMap) {
void setDeviceMap(const std::string& workerName, const DeviceMap& deviceMap) {
auto iter = deviceMaps.find(workerName);
if (iter == deviceMaps.end()) {
deviceMaps[workerName] = deviceMap;
} else {
for (auto& entry : deviceMap) {
iter->second[entry.first] = entry.second;
// c10::Device has no default constructor, hence map[device] doesn't work.
// In C++17 we can use insert_or_assign.
auto entryIter = iter->second.find(entry.first);
if (entryIter == iter->second.end()) {
iter->second.emplace(entry.first, entry.second);
} else {
entryIter->second = entry.second;
}
}
}
}
@ -135,8 +140,8 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions {
int numWorkerThreads;
const optional<std::vector<std::string>> transports;
const optional<std::vector<std::string>> channels;
std::unordered_map<std::string, tensorpipe::DeviceMap> deviceMaps;
std::vector<c10::DeviceIndex> devices;
std::unordered_map<std::string, DeviceMap> deviceMaps;
std::vector<c10::Device> devices;
};
// Struct to track the network source metrics
@ -175,8 +180,7 @@ class TensorPipeAgent : public RpcAgent {
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) override;
const DeviceMap& deviceMap = {}) override;
// join() and sync() would be deprecated -
// https://github.com/pytorch/pytorch/issues/27647
@ -191,14 +195,13 @@ class TensorPipeAgent : public RpcAgent {
const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override;
std::vector<WorkerInfo> getWorkerInfos() const override;
void setReverseDeviceMaps(
const std::unordered_map<std::string, tensorpipe::DeviceMap>&
reverseDeviceMaps);
const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps);
std::unordered_map<std::string, std::string> getMetrics() override;
void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
tensorpipe::DeviceMap getDeviceMap(const WorkerInfo& dest) const override;
DeviceMap getDeviceMap(const WorkerInfo& dest) const override;
using NetworkDataDict =
std::unordered_map<std::string, AggregatedNetworkData>;
@ -239,10 +242,9 @@ class TensorPipeAgent : public RpcAgent {
void pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>&,
Message&& message,
std::vector<c10::DeviceIndex>&& devices,
std::vector<c10::Device>&& devices,
std::shared_ptr<LazyStreamContext> ctx,
std::function<void(const tensorpipe::Error&)>,
const tensorpipe::DeviceMap& deviceMap = {}) noexcept;
std::function<void(const tensorpipe::Error&)>) noexcept;
// Callback of listener accept()
void onListenerAccepted(
@ -269,7 +271,7 @@ class TensorPipeAgent : public RpcAgent {
uint64_t requestSize,
const std::string& destWorkerName);
inline std::vector<c10::DeviceIndex> getDevicesForRemote(
inline std::vector<c10::Device> getDevicesForRemote(
const std::string& remoteName,
const Message& message) const;
@ -280,14 +282,9 @@ class TensorPipeAgent : public RpcAgent {
// then, it ends up printing a log message, which may worry the user. To solve
// both issues we use a separate atomic flag to know the status of the future.
struct AtomicJitFuture {
explicit AtomicJitFuture(const std::vector<c10::DeviceIndex>& devices) {
std::vector<c10::Device> fullDevices;
fullDevices.reserve(devices.size());
for (const c10::DeviceIndex index : devices) {
fullDevices.emplace_back(c10::kCUDA, index);
}
explicit AtomicJitFuture(const std::vector<c10::Device>& devices) {
jitFuture = std::make_shared<at::ivalue::Future>(
at::AnyClassType::get(), std::move(fullDevices));
at::AnyClassType::get(), devices);
}
std::atomic_flag isComplete = ATOMIC_FLAG_INIT;
@ -310,11 +307,11 @@ class TensorPipeAgent : public RpcAgent {
};
const TensorPipeRpcBackendOptions opts_;
std::unordered_map<std::string, tensorpipe::DeviceMap> reverseDeviceMaps_;
std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
// Local devices used by this agent. If application didn't specify this
// field, it will be initialized using corresponding local devices in
// opts_.deviceMaps and reverseDeviceMaps_;
std::vector<c10::DeviceIndex> devices_;
std::vector<c10::Device> devices_;
ThreadPool threadPool_;
std::shared_ptr<tensorpipe::Context> context_;

View file

@ -42,7 +42,7 @@ inline c10::Device indexToDevice(c10::DeviceIndex index) {
std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
Message&& rpcMessage,
std::vector<c10::DeviceIndex> deviceIndices,
std::vector<c10::Device> devices,
const std::shared_ptr<LazyStreamContext>& ctx) {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers buffers;
@ -89,10 +89,9 @@ std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
// tensor to CPU.
const auto& tensorData =
jit::getWriteableTensorData(tensorDataVec[i], /* toCpu */ false);
tensorpipe::Device targetDevice =
deviceIndices.empty() || deviceIndices[i] == -1
tensorpipe::Device targetDevice = devices.empty() || devices[i].is_cpu()
? tensorpipe::Device{tensorpipe::kCpuDeviceType, 0}
: tensorpipe::Device{tensorpipe::kCudaDeviceType, deviceIndices[i]};
: tensorpipe::Device{tensorpipe::kCudaDeviceType, devices[i].index()};
// Enforce memory copy if tensor is created from torch::from_blob, means
// that the tensor doesn't own the memory.
@ -126,8 +125,8 @@ std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
tpMessage.tensors.push_back(std::move(tensor));
#ifdef USE_CUDA_NOT_ROCM
} else if (tensorDataVec[i].device().is_cuda()) {
auto stream = at::cuda::CUDAStream(
ctx->getStream(tensorDataVec[i].device().index()));
auto stream =
at::cuda::CUDAStream(ctx->getStream(tensorDataVec[i].device()));
tensorpipe::CudaBuffer buffer;
buffer.ptr = tensorPtr;
buffer.stream = stream.stream();
@ -209,8 +208,8 @@ std::pair<tensorpipe::Allocation, TensorpipeReadBuffers> tensorpipeAllocate(
tpAllocation.tensors[tensorIdx].buffer = buffer;
#ifdef USE_CUDA_NOT_ROCM
} else if (tensor.targetDevice->type == tensorpipe::kCudaDeviceType) {
auto deviceIndex = tensor.targetDevice->index;
auto stream = at::cuda::CUDAStream(ctx->getStream(deviceIndex));
c10::Device device(c10::kCUDA, tensor.targetDevice->index);
auto stream = at::cuda::CUDAStream(ctx->getStream(device));
// CUDACachingAllocator will call recordStream accordingly on the current
// stream.
at::cuda::CUDAStreamGuard guard(stream);

View file

@ -45,30 +45,13 @@ struct TensorpipeReadBuffers {
std::vector<c10::DataPtr> tensors;
};
inline std::shared_ptr<LazyStreamContext> createLazyStreamContext() {
return createLazyStreamContext(
#ifdef USE_CUDA_NOT_ROCM
c10::DeviceType::CUDA,
[](c10::DeviceIndex index) {
return at::cuda::getStreamFromPool(
/* isHighPriority */ false, /* device */ index);
},
[](c10::DeviceIndex index) {
return at::cuda::getCurrentCUDAStream(index);
}
#else
c10::DeviceType::CPU, nullptr, nullptr
#endif
);
}
// Convert an RPC message into a TensorPipe message, plus a holder to all the
// data that must be kept alive while the write is performed asynchronously.
TORCH_API std::tuple<tensorpipe::Message, TensorpipeWriteBuffers>
tensorpipeSerialize(
Message&& rpcMessage,
std::vector<c10::DeviceIndex> devices = {},
const std::shared_ptr<LazyStreamContext>& = createLazyStreamContext());
std::vector<c10::Device> devices,
const std::shared_ptr<LazyStreamContext>& ctx);
// Allocate the buffers that will hold the incoming data. They will be managed
// by the returned holder, which must be kept alive until the asynchronous read
@ -77,7 +60,7 @@ tensorpipeSerialize(
TORCH_API std::pair<tensorpipe::Allocation, TensorpipeReadBuffers>
tensorpipeAllocate(
const tensorpipe::Descriptor& tpDescriptor,
const std::shared_ptr<LazyStreamContext>& ctx = createLazyStreamContext());
const std::shared_ptr<LazyStreamContext>& ctx);
// Convert a TensorPipe message back into an RPC message. This requires the data
// to be available and can thus only be performed once the asynchronous read has

View file

@ -62,7 +62,7 @@ std::shared_ptr<JitFuture> FaultyProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap) {
const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
// We only fail control messages that have been specified by the test case.
// For all other messages, we just send them without any failures.
if (!shouldFailMessage(message.type())) {

View file

@ -48,8 +48,8 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
const std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>& deviceMap =
{}) override;
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
override;
protected:
// This function checks the messageTypesToFail_ to determine whether to use

View file

@ -175,7 +175,7 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(
// Need to reverse the device map for the backward pass of distributed
// autograd.
std::unordered_map<c10::DeviceIndex, c10::DeviceIndex> reverseDeviceMap;
std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
}

View file

@ -89,8 +89,6 @@ TORCH_API void populateRemoteProfiledEvents(
const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
eventLists);
using stream_factory_t = std::function<c10::Stream(c10::DeviceIndex)>;
// A general device context class for both CPU and CUDA. If CUDA is not
// available, all CUDA-related methods will be no-ops.
struct TORCH_API LazyStreamContext {
@ -99,39 +97,26 @@ struct TORCH_API LazyStreamContext {
LazyStreamContext& operator=(const LazyStreamContext& rhs) = delete;
LazyStreamContext& operator=(LazyStreamContext&& rhs) & = delete;
LazyStreamContext(
c10::DeviceType device_type,
stream_factory_t stream_creator,
stream_factory_t current_stream_provider)
: device_type_(device_type),
stream_creator_(std::move(stream_creator)),
current_stream_provider_(std::move(current_stream_provider)) {}
explicit LazyStreamContext(c10::DeviceType device_type)
: impl_(device_type) {}
// let streams in this context wait for current streams.
void waitForCurrentStreams(const std::vector<torch::Tensor>& tensors = {}) {
if (!stream_creator_) {
// Since the stream_creator is empty the device doesn't support streams
return;
}
for (const auto& tensor : tensors) {
if (tensor.is_cuda()) {
getStream(tensor.device().index());
getStream(tensor.device());
}
}
for (const auto& entry : streams_) {
c10::Event event{device_type_};
event.record(current_stream_provider_(entry.first));
c10::Event event{impl_.type()};
event.record(impl_.getStream(entry.first));
event.block(entry.second);
}
}
// get all streams used in this context
std::vector<c10::Stream> getReservedStreams() const {
if (!stream_creator_) {
// Since the stream_creator is empty the device doesn't support streams
return {};
}
std::vector<c10::Stream> reservedStreams;
reservedStreams.reserve(streams_.size());
for (const auto& entry : streams_) {
@ -142,25 +127,19 @@ struct TORCH_API LazyStreamContext {
// get a stream for the given device. If it is the first time using that
// device, allocate a new stream and store it in the map.
c10::Stream getStream(c10::DeviceIndex index) {
if (!stream_creator_) {
throw std::runtime_error(c10::str(
"Attempting to access device stream of device ",
index,
", but the device doesn't support streams"));
}
auto iter = streams_.find(index);
c10::Stream getStream(c10::Device device) {
auto iter = streams_.find(device);
if (iter == streams_.end()) {
auto stream = stream_creator_(index);
streams_.emplace(index, stream);
auto stream = impl_.getStreamFromGlobalPool(device);
streams_.emplace(device, stream);
return stream;
} else {
return iter->second;
}
}
std::set<c10::DeviceIndex> devices() const {
std::set<c10::DeviceIndex> devices;
std::unordered_set<c10::Device> devices() const {
std::unordered_set<c10::Device> devices;
for (const auto& entry : streams_) {
devices.insert(entry.first);
}
@ -168,26 +147,14 @@ struct TORCH_API LazyStreamContext {
}
c10::DeviceType deviceType() const {
return device_type_;
return impl_.type();
}
private:
std::unordered_map<c10::DeviceIndex, c10::Stream> streams_;
c10::DeviceType device_type_;
stream_factory_t stream_creator_;
stream_factory_t current_stream_provider_;
const c10::impl::VirtualGuardImpl impl_;
std::unordered_map<c10::Device, c10::Stream> streams_;
};
inline std::shared_ptr<LazyStreamContext> createLazyStreamContext(
c10::DeviceType device_type,
stream_factory_t stream_creator,
stream_factory_t current_stream_provider) {
return std::make_shared<LazyStreamContext>(
device_type,
std::move(stream_creator),
std::move(current_stream_provider));
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View file

@ -206,12 +206,12 @@ def _tensorpipe_check_local_device_maps(name, options):
f"Invalid device_map configuration for {worker_name}, "
f"not 1-to-1 mapping:\ndevice_maps = {device_map}"
)
local_devices.update(options.device_maps[worker_name].keys())
local_devices.update(key_set)
if len(local_devices) > 0 and not all([
max(local_devices) < torch.cuda.device_count(),
min(local_devices) >= 0,
]):
if not all(
(0 <= d.index < torch.cuda.device_count() if d.type == "cuda" else True)
for d in local_devices
):
raise ValueError(
f"Invalid device in TensorPipe options on {name}:\n"
f"device_maps = {options.device_maps},\n"
@ -235,12 +235,11 @@ def _tensorpipe_check_remote_device_maps(agent, options):
remote_device_count = all_device_counts[remote_name]
if remote_name in device_maps:
device_map = device_maps[remote_name]
key_set = set(device_map.keys())
val_set = set(device_map.values())
if not all([
min(val_set) >= 0,
max(val_set) < remote_device_count # check remote range
]):
if not all(
(0 <= d.index < remote_device_count if d.type == "cuda" else True)
for d in val_set
):
raise ValueError(
f"Invalid device_map configuration on {name} "
f"for {remote_name}, remote device out of range:\n"

View file

@ -9,19 +9,19 @@ from typing import Dict, List, Optional, Union
DeviceType = Union[int, str, torch.device]
def _to_device_index(device: DeviceType) -> int:
def _to_device(device: DeviceType) -> torch.device:
device = torch.device(device)
if device.type != "cuda":
raise ValueError(
"`set_devices` expect a list of CUDA devices, but got "
f"device type {device.type}."
)
return device.index
return device
def _to_device_index_map(device_map: Dict[DeviceType, DeviceType]) -> Dict[int, int]:
device_index_map : Dict[int, int] = {}
reverse_map : Dict[int, int] = {}
def _to_device_map(device_map: Dict[DeviceType, DeviceType]) -> Dict[torch.device, torch.device]:
full_device_map : Dict[torch.device, torch.device] = {}
reverse_map : Dict[torch.device, torch.device] = {}
for k in device_map:
v = device_map[k]
k, v = torch.device(k), torch.device(v)
@ -31,18 +31,18 @@ def _to_device_index_map(device_map: Dict[DeviceType, DeviceType]) -> Dict[int,
f"but got device pair {k}: {v}"
)
if v.index in reverse_map:
if v in reverse_map:
raise ValueError(
"`device_map` only supports 1-to-1 mapping, "
f"trying to map {k} and {reverse_map[v.index]} to {v.index}"
f"trying to map {k} and {reverse_map[v]} to {v}"
)
device_index_map[k.index] = v.index
reverse_map[v.index] = k.index
return device_index_map
full_device_map[k] = v
reverse_map[v] = k
return full_device_map
def _to_device_index_list(devices: List[DeviceType]) -> List[int]:
return list(map(_to_device_index, devices))
def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
return list(map(_to_device, devices))
@ -90,13 +90,13 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
_transports: List = None,
_channels: List = None,
):
device_index_maps = (
full_device_maps = (
{} if device_maps is None else
{k : _to_device_index_map(v) for k, v in device_maps.items()}
{k : _to_device_map(v) for k, v in device_maps.items()}
)
device_index_list = (
full_device_list = (
[] if devices is None else
_to_device_index_list(devices)
_to_device_list(devices)
)
super().__init__(
num_worker_threads,
@ -104,8 +104,8 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
_channels,
rpc_timeout,
init_method,
device_index_maps,
device_index_list,
full_device_maps,
full_device_list,
)
def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
@ -152,18 +152,18 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
>>> print(rets[0]) # tensor([2., 2.], device='cuda:0')
>>> print(rets[1]) # tensor([2., 2.], device='cuda:1')
"""
device_index_map = _to_device_index_map(device_map)
full_device_map = _to_device_map(device_map)
curr_device_maps = super().device_maps
if to in curr_device_maps:
for k, v in device_index_map.items():
for k, v in full_device_map.items():
if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
raise ValueError(
"`set_device_map` only supports 1-to-1 mapping, trying"
f" to map {k} to {v} and {curr_device_maps[to][k]}"
)
super()._set_device_map(to, device_index_map)
super()._set_device_map(to, full_device_map)
def set_devices(self, devices: List[DeviceType]):
r"""
@ -175,4 +175,4 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
devices (List of int, str, or torch.device): local devices used by
the TensorPipe RPC agent.
"""
self.devices = _to_device_index_list(devices)
self.devices = _to_device_list(devices)