diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 14b37bd668d..83670d6d0eb 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -46,7 +46,7 @@ class TestE2EBase : public ::testing::Test { return c10::StrongTypePtr( nullptr, c10::DictType::create( - c10::IntType::get(), c10::IntType::get())); + c10::StringType::get(), c10::StringType::get())); } return c10::StrongTypePtr( nullptr, c10::TensorType::create(at::Tensor())); diff --git a/test/cpp/rpc/test_tensorpipe_serialization.cpp b/test/cpp/rpc/test_tensorpipe_serialization.cpp index 48cffa71345..5046d75b9ca 100644 --- a/test/cpp/rpc/test_tensorpipe_serialization.cpp +++ b/test/cpp/rpc/test_tensorpipe_serialization.cpp @@ -11,6 +11,8 @@ TEST(TensorpipeSerialize, Base) { // Sender serializes + auto lazyStreamCtx = + std::make_shared(c10::kCPU); at::Tensor t1 = torch::ones({1024}, at::ScalarType::Int); at::Tensor t2 = torch::ones({1024}, at::ScalarType::Float); std::vector tensors{t1, t2}; @@ -26,7 +28,7 @@ TEST(TensorpipeSerialize, Base) { torch::distributed::rpc::TensorpipeWriteBuffers sendingTpBuffers; std::tie(sendingTpMessage, sendingTpBuffers) = torch::distributed::rpc::tensorpipeSerialize( - std::move(sendingRpcMessage)); + std::move(sendingRpcMessage), {}, lazyStreamCtx); // Mimic receiving message descriptor: recvingTpDescriptor is a copy of // sendingTpMessage except for the data pointers which are left null. @@ -59,7 +61,8 @@ TEST(TensorpipeSerialize, Base) { tensorpipe::Allocation recvingTpAllocation; torch::distributed::rpc::TensorpipeReadBuffers recvingTpBuffers; std::tie(recvingTpAllocation, recvingTpBuffers) = - torch::distributed::rpc::tensorpipeAllocate(recvingTpDescriptor); + torch::distributed::rpc::tensorpipeAllocate( + recvingTpDescriptor, lazyStreamCtx); // Mimic tensorpipe data transfer EXPECT_EQ( @@ -101,6 +104,8 @@ TEST(TensorpipeSerialize, Base) { TEST(TensorpipeSerialize, RecopySparseTensors) { // Take a 1K row of a 1M tensors, and make sure we don't send across 1M rows. + auto lazyStreamCtx = + std::make_shared(c10::kCPU); constexpr size_t k1K = 1024; at::Tensor main = torch::randn({k1K, k1K}); at::Tensor tiny = main.select(0, 2); // Select a row in the middle @@ -118,7 +123,7 @@ TEST(TensorpipeSerialize, RecopySparseTensors) { torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers; std::tie(sendingTpMessage, tpBuffers) = torch::distributed::rpc::tensorpipeSerialize( - std::move(sendingRpcMessage)); + std::move(sendingRpcMessage), {}, lazyStreamCtx); EXPECT_EQ(tpBuffers.tensors.size(), 2); EXPECT_EQ(sendingTpMessage.tensors.size(), 2); @@ -135,6 +140,8 @@ TEST(TensorpipeSerialize, RecopySparseTensors) { } TEST(TensorpipeSerialize, NoDeleterTensors) { + auto lazyStreamCtx = + std::make_shared(c10::kCPU); std::vector blob1{.8, .2}; std::vector blob2{.7, .5, .9}; at::Tensor t1 = torch::from_blob((float*)(blob1.data()), blob1.size()); @@ -150,7 +157,7 @@ TEST(TensorpipeSerialize, NoDeleterTensors) { torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers; std::tie(sendingTpMessage, tpBuffers) = torch::distributed::rpc::tensorpipeSerialize( - std::move(sendingRpcMessage)); + std::move(sendingRpcMessage), {}, lazyStreamCtx); EXPECT_EQ(tpBuffers.copiedTensors.size(), 2); EXPECT_EQ(sendingTpMessage.tensors.size(), 2); diff --git a/torch/_C/_distributed_rpc.pyi b/torch/_C/_distributed_rpc.pyi index 83e5008ec0b..17bdcba5d4d 100644 --- a/torch/_C/_distributed_rpc.pyi +++ b/torch/_C/_distributed_rpc.pyi @@ -41,7 +41,7 @@ class RpcAgent: @overload def get_worker_info(self, workerName: str) -> WorkerInfo: ... def get_worker_infos(self) -> List[WorkerInfo]: ... - def _get_device_map(self, dst: WorkerInfo) -> Dict[int, int]: ... + def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ... def get_debug_info(self) -> Dict[str, str]: ... def get_metrics(self) -> Dict[str, str]: ... @@ -91,14 +91,14 @@ class ProcessGroupAgent(RpcAgent): @overload def get_worker_info(self, id: int) -> WorkerInfo: ... def get_worker_infos(self) -> List[WorkerInfo]: ... - def _get_device_map(self, dst: WorkerInfo) -> Dict[int, int]: ... + def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ... def join(self): ... def shutdown(self): ... def sync(self): ... class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions): num_worker_threads: int - device_maps: Dict[str, Dict[int, int]] + device_maps: Dict[str, Dict[torch.device, torch.device]] def __init__( self, num_worker_threads: int, @@ -106,9 +106,9 @@ class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions): _channels: Optional[List], rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC, init_method: str = _DEFAULT_INIT_METHOD, - device_maps: Dict[str, Dict[int, int]] = dict(), - devices: List[int] = list()): ... - def _set_device_map(self, to: str, device_map: Dict[int, int]): ... + device_maps: Dict[str, Dict[torch.device, torch.device]] = dict(), + devices: List[torch.device] = list()): ... + def _set_device_map(self, to: str, device_map: Dict[torch.device, torch.device]): ... class TensorPipeAgent(RpcAgent): def __init__( @@ -129,8 +129,8 @@ class TensorPipeAgent(RpcAgent): @overload def get_worker_info(self, id: int) -> WorkerInfo: ... def get_worker_infos(self) -> List[WorkerInfo]: ... - def _get_device_map(self, dst: WorkerInfo) -> Dict[int, int]: ... - def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ... + def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ... + def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[torch.device, torch.device]]): ... def _is_current_rpc_agent_set() -> bool: ... def _get_current_rpc_agent()-> RpcAgent: ... diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 6e79246c48f..981b1a57285 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -14,7 +14,7 @@ RecvRpcBackward::RecvRpcBackward( const AutogradMetadata& autogradMetadata, ContextPtr autogradContext, rpc::worker_id_t fromWorkerId, - std::unordered_map deviceMap) + std::unordered_map deviceMap) : autogradMetadata_(autogradMetadata), // NOLINTNEXTLINE(performance-move-const-arg) autogradContext_(std::move(autogradContext)), diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h index 69be98c928e..46bdb297cdf 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h @@ -23,7 +23,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node { const AutogradMetadata& autogradMetadata, std::shared_ptr autogradContext, rpc::worker_id_t fromWorkerId, - std::unordered_map deviceMap); + std::unordered_map deviceMap); torch::autograd::variable_list apply( torch::autograd::variable_list&& grads) override; @@ -41,7 +41,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node { rpc::worker_id_t fromWorkerId_; // Device mapping for tensors sent over RPC. - const std::unordered_map deviceMap_; + const std::unordered_map deviceMap_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp index 5aea96fa0c8..40ddee7594c 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp @@ -19,7 +19,7 @@ RpcWithAutograd::RpcWithAutograd( MessageType messageType, const AutogradMetadata& autogradMetadata, rpc::Message&& wrappedMessage, - std::unordered_map deviceMap) + std::unordered_map deviceMap) : fromWorkerId_(fromWorkerId), messageType_(messageType), autogradMetadata_(autogradMetadata), @@ -39,7 +39,7 @@ RpcWithAutograd::RpcWithAutograd( std::unique_ptr wrappedRpc, MessageType wrappedMessageType, std::vector tensors, - std::unordered_map deviceMap) + std::unordered_map deviceMap) : fromWorkerId_(fromWorkerId), messageType_(messageType), autogradMetadata_(autogradMetadata), @@ -61,9 +61,9 @@ Message RpcWithAutograd::toMessageImpl() && { TORCH_INTERNAL_ASSERT(!payload.empty()); // Convert deviceMap to c10::Dict for serialization. - c10::Dict deviceMap; + c10::Dict deviceMap; for (const auto& mapEntry : deviceMap_) { - deviceMap.insert(mapEntry.first, mapEntry.second); + deviceMap.insert(mapEntry.first.str(), mapEntry.second.str()); } std::vector ivalues{wrappedMessageType, @@ -109,10 +109,10 @@ std::unique_ptr RpcWithAutograd::fromMessage( AutogradMetadata autogradMetadata( tupleElements[1].toInt(), tupleElements[2].toInt()); worker_id_t workerId = tupleElements[3].toInt(); - auto c10DeviceMap = tupleElements[4].to>(); + auto c10DeviceMap = tupleElements[4].to>(); // Convert to regular map. - std::unordered_map deviceMap; + std::unordered_map deviceMap; for (const auto& mapEntry : c10DeviceMap) { deviceMap.insert({mapEntry.key(), mapEntry.value()}); } @@ -169,7 +169,7 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const { return fromWorkerId_; } -const std::unordered_map& RpcWithAutograd:: +const std::unordered_map& RpcWithAutograd:: deviceMap() { return deviceMap_; } diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h index f4728ea37c6..08ce35d6096 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h @@ -19,7 +19,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { rpc::MessageType messageType, const AutogradMetadata& autogradMetadata, rpc::Message&& wrappedMessage, - std::unordered_map deviceMap = {}); + std::unordered_map deviceMap = {}); // Used when receiving an RPC over the wire. RpcWithAutograd( @@ -29,7 +29,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { std::unique_ptr wrappedRpc, rpc::MessageType wrappedMessageType, std::vector tensors, - std::unordered_map deviceMap = {}); + std::unordered_map deviceMap = {}); rpc::Message toMessageImpl() && override; @@ -55,7 +55,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { rpc::worker_id_t fromWorkerId() const; // Retrieve the device map. - const std::unordered_map& deviceMap(); + const std::unordered_map& deviceMap(); private: // WorkerId from which this RPC originated. This is necessary for knowing @@ -90,7 +90,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { std::vector tensors_; // Device mapping for tensors that are sent across an RPC to another node. - std::unordered_map deviceMap_; + std::unordered_map deviceMap_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 205e50945dd..21785b631d3 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -53,7 +53,7 @@ ContextPtr addRecvRpcBackward( const AutogradMetadata& autogradMetadata, std::vector& tensors, rpc::worker_id_t fromWorkerId, - const std::unordered_map& deviceMap) { + const std::unordered_map& deviceMap) { // Initialize autograd context if necessary. auto& autogradContainer = DistAutogradContainer::getInstance(); auto autogradContext = @@ -105,7 +105,7 @@ Message getMessageWithAutograd( torch::distributed::rpc::Message&& wrappedRpcMsg, MessageType msgType, bool forceGradRecording, - const std::unordered_map& deviceMap) { + const std::unordered_map& deviceMap) { auto& autogradContainer = DistAutogradContainer::getInstance(); // If there is no valid context and no tensor requires grads, send original diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 013558252fc..0bf542be3ef 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -31,7 +31,7 @@ TORCH_API ContextPtr addRecvRpcBackward( const AutogradMetadata& autogradMetadata, std::vector& tensors, rpc::worker_id_t fromWorkerId, - const std::unordered_map& deviceMap); + const std::unordered_map& deviceMap); // This method is a wrapper utility used internally to wrap autograd info // and attach autograd function for each type of rpc call if it has valid @@ -44,7 +44,7 @@ TORCH_API rpc::Message getMessageWithAutograd( rpc::Message&& wrappedRpcMsg, rpc::MessageType msgType, bool forceGradRecording = false, - const std::unordered_map& deviceMap = + const std::unordered_map& deviceMap = {}); // Send message after autograd checking diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 35e40520cfa..e479badb386 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -635,16 +635,15 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { optional>, float, std::string, - std::unordered_map, - std::vector>(), + std::unordered_map, + std::vector>(), py::arg("num_worker_threads") = kDefaultNumWorkerThreads, py::arg("_transports") = optional>(), py::arg("_channels") = optional>(), py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, py::arg("init_method") = kDefaultInitMethod, - py::arg("device_maps") = - std::unordered_map(), - py::arg("devices") = std::vector()) + py::arg("device_maps") = std::unordered_map(), + py::arg("devices") = std::vector()) .def_readwrite( "num_worker_threads", &TensorPipeRpcBackendOptions::numWorkerThreads, @@ -722,8 +721,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { py::call_guard()) .def( "_get_device_map", - (tensorpipe::DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) - const) & + (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) & TensorPipeAgent::getDeviceMap, py::call_guard()) .def( diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 34d96210043..71c7f3f042f 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -270,7 +270,7 @@ std::shared_ptr ProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds, - const std::unordered_map& deviceMap) { + const std::unordered_map& /* unused */) { // Throw if we previously encountered an exception in ::listenLoop. { std::unique_lock guard(listenLoopExceptionMutex_); diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 2ee28a3b58d..2bb8d6f90d9 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -94,8 +94,8 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout, - const std::unordered_map& deviceMap = - {}) override; + const std::unordered_map& deviceMap = {}) + override; // put SendWork into a queue and notify the worker thread virtual void enqueueSend(SendWork work); diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 7ca2a4c91ad..cd093a45116 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -361,7 +361,7 @@ void RequestCallbackNoPython::processForwardAutogradReq( // Need to reverse the device map for the backward pass of distributed // autograd. - std::unordered_map reverseDeviceMap; + std::unordered_map reverseDeviceMap; for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); } diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 83a9965792e..50d69b277eb 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -287,7 +287,7 @@ bool RpcAgent::isGILProfilingEnabled() { return profilingEnabled_.load(); } -std::unordered_map RpcAgent::getDeviceMap( +std::unordered_map RpcAgent::getDeviceMap( const WorkerInfo& /* unused */) const { // Default implementation has no device map. return {}; diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index 4dd5e7c8953..7991a543c11 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -164,8 +164,7 @@ class TORCH_API RpcAgent { const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout, - const std::unordered_map& deviceMap = - {}) = 0; + const std::unordered_map& deviceMap = {}) = 0; // Retries sending the message up to maxRetries times until an ACK is // receieved. The duration between consecutive sends is increased over @@ -265,7 +264,7 @@ class TORCH_API RpcAgent { std::shared_ptr getTypeResolver(); // Retrieves the device map for the provided destination worker. - virtual std::unordered_map getDeviceMap( + virtual std::unordered_map getDeviceMap( const WorkerInfo& dst) const; protected: diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 6c6f2cebb68..b06191585cc 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -314,7 +314,7 @@ void OwnerRRef::recordAllStreams( void OwnerRRef::blockAllStreams(std::shared_ptr& ctx) { if (ctx) { for (c10::Event& event : events_) { - event.block(ctx->getStream(event.device_index())); + event.block(ctx->getStream(event.device())); } } } diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 12ea96373f9..4a0ba70b5da 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -41,9 +41,9 @@ const std::string kClientActiveCalls = "agent.client_active_calls"; const std::string kServerActiveCalls = "agent.server_active_calls"; const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls"; -std::vector getDevicesForTensors( +std::vector getDevicesForTensors( const std::vector& tensors, - const tensorpipe::DeviceMap& deviceMap, + const DeviceMap& deviceMap, const std::string& remoteName) { // If the deviceMap is overridden, use that instead. const auto errStr = c10::str( @@ -53,45 +53,46 @@ std::vector getDevicesForTensors( "configure device mapping. ", "Request device mapping is not available for destination ", remoteName); - std::vector deviceIndices; - deviceIndices.reserve(tensors.size()); + std::vector devices; + devices.reserve(tensors.size()); bool hasCudaTensor = false; for (const auto& t : tensors) { if (t.device().is_cpu()) { - deviceIndices.push_back(-1); + devices.emplace_back(c10::kCPU); } else { - const auto deviceIter = deviceMap.find(t.device().index()); + const auto deviceIter = deviceMap.find(t.device()); TORCH_CHECK( deviceIter != deviceMap.end(), errStr, " for device ", t.device(), " but received a tensor on that device."); - deviceIndices.push_back(deviceIter->second); + devices.push_back(deviceIter->second); hasCudaTensor = true; } } if (!hasCudaTensor) { - deviceIndices.clear(); + devices.clear(); } - return deviceIndices; + return devices; } // A helper function that first creates a LazyStreamContext and then grabs a // CUDA stream for each device in the given device list. std::shared_ptr createCalleeStreamContext( - std::vector devices) { - auto ctx = createLazyStreamContext(); - for (const auto& device : devices) { + const std::vector& devices) { + auto ctx = std::make_shared( + devices.empty() ? c10::kCPU : devices[0].type()); + for (const c10::Device& device : devices) { ctx->getStream(device); } return ctx; } // Retrieve local devices (i.e., device keys) from the given map. -std::unordered_set getLocalDevices( - const std::unordered_map& deviceMap) { - std::unordered_set deviceSet; +std::unordered_set getLocalDevices( + const std::unordered_map& deviceMap) { + std::unordered_set deviceSet; for (const auto& entry : deviceMap) { for (const auto& device : entry.second) { deviceSet.insert(device.first); @@ -103,9 +104,9 @@ std::unordered_set getLocalDevices( // 1) checks there is no duplication in the devices field // 2) checks all local devices in the deviceSet are included in deviceOpt void checkValidDevicesOption( - const std::unordered_set& deviceSet, - const std::vector& deviceOpt) { - std::unordered_set optsDeviceSet( + const std::unordered_set& deviceSet, + const std::vector& deviceOpt) { + std::unordered_set optsDeviceSet( deviceOpt.begin(), deviceOpt.end()); // no duplications are allowed in opts_.devices @@ -114,18 +115,17 @@ void checkValidDevicesOption( "Detected duplication in TensorPipeRpcBackendOptions devices field."); // opts_.devices must be a superset of local devices in reverseDeviceMaps_ - std::vector cut; - std::set_difference( - deviceSet.begin(), - deviceSet.end(), - optsDeviceSet.begin(), - optsDeviceSet.end(), - std::back_inserter(cut)); + std::unordered_set cut; + for (const c10::Device& device : deviceSet) { + if (optsDeviceSet.find(device) == optsDeviceSet.end()) { + cut.insert(device); + } + } if (!cut.empty()) { std::ostringstream oss; std::copy( - cut.begin(), cut.end(), std::ostream_iterator(oss, ", ")); + cut.begin(), cut.end(), std::ostream_iterator(oss, ", ")); TORCH_CHECK( false, "The devices field in TensorPipeRpcBackendOptions must either be " @@ -511,8 +511,7 @@ TensorPipeAgent::~TensorPipeAgent() { } void TensorPipeAgent::setReverseDeviceMaps( - const std::unordered_map& - reverseDeviceMaps) { + const std::unordered_map& reverseDeviceMaps) { reverseDeviceMaps_ = reverseDeviceMaps; // If devices wasn't specified in the options, update devices_ with local @@ -708,11 +707,9 @@ void TensorPipeAgent::pipeRead( void TensorPipeAgent::pipeWrite( const std::shared_ptr& pipe, Message&& rpcMessage, - std::vector&& devices, + std::vector&& devices, std::shared_ptr ctx, - std::function fn, - const std::unordered_map& - deviceMap) noexcept { + std::function fn) noexcept { tensorpipe::Message tpMessage; TensorpipeWriteBuffers tpBuffers; @@ -748,7 +745,7 @@ void TensorPipeAgent::sendCompletedResponseMessage( std::move(*futureResponseMessage->value().toCustomClass()); responseMessage.setId(messageId); - std::vector devices; + std::vector devices; try { devices = getDevicesForRemote(pipe->getRemoteName(), responseMessage); } catch (const std::exception& e) { @@ -760,19 +757,17 @@ void TensorPipeAgent::sendCompletedResponseMessage( // FIXME: skipping this check when ctxDevices is empty to allow // RRef.to_here(). for (const auto& tensor : responseMessage.tensors()) { - const auto device = tensor.device().index(); - if (device != -1 && ctxDevices.find(device) == ctxDevices.end()) { + const auto device = tensor.device(); + if (!device.is_cpu() && ctxDevices.find(device) == ctxDevices.end()) { std::ostringstream oss; std::copy( ctxDevices.begin(), ctxDevices.end(), - // interpreting c10::DeviceIndex as int32_t to avoid printing - // it as a char. - std::ostream_iterator(oss, ", ")); + std::ostream_iterator(oss, ", ")); responseMessage = createExceptionResponse( c10::str( "RPC detected that a user-function output tensor on device ", - int32_t(device), + device, ". This device is not one of the input tensor devices: ", oss.str(), "which is not yet supported. Please file a feature request " @@ -913,7 +908,7 @@ std::shared_ptr TensorPipeAgent::send( const WorkerInfo& toWorkerInfo, Message&& requestMessage, const float rpcTimeoutSeconds, - const std::unordered_map& deviceMap) { + const DeviceMap& deviceMap) { TORCH_CHECK( requestMessage.isRequest(), "TensorPipeAgent::send(..) is only for sending requests."); @@ -960,7 +955,7 @@ std::shared_ptr TensorPipeAgent::send( // Get devices for tensors in the request message. This can throw if device // maps are not configured properly for this request. - std::vector devices; + std::vector devices; if (deviceMap.empty()) { devices = getDevicesForRemote(clientPipe.pipe_->getRemoteName(), requestMessage); @@ -1006,7 +1001,8 @@ std::shared_ptr TensorPipeAgent::send( VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #" << messageId << " to " << clientPipe.pipe_->getRemoteName(); - auto ctx = createLazyStreamContext(); + auto ctx = std::make_shared( + devices_.empty() ? c10::kCPU : devices_[0].type()); ctx->waitForCurrentStreams(requestMessage.tensors()); pipeWrite( clientPipe.pipe_, @@ -1093,8 +1089,7 @@ std::shared_ptr TensorPipeAgent::send( std::move(ctx)); } }); - }, - deviceMap); + }); return futureResponseMessage->jitFuture; } @@ -1419,7 +1414,7 @@ void TensorPipeAgent::markFutureWithError( } } -std::vector TensorPipeAgent::getDevicesForRemote( +std::vector TensorPipeAgent::getDevicesForRemote( const std::string& remoteName, const Message& message) const { const auto& deviceMaps = @@ -1449,8 +1444,7 @@ std::vector TensorPipeAgent::getDevicesForRemote( } } -tensorpipe::DeviceMap TensorPipeAgent::getDeviceMap( - const WorkerInfo& dst) const { +DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const { auto it = opts_.deviceMaps.find(dst.name_); if (it == opts_.deviceMaps.end()) { return {}; diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index 24d00e8c45e..b1f40ffd78f 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -37,14 +37,14 @@ namespace channel { class Context; } // namespace channel -using DeviceMap = std::unordered_map; - } // namespace tensorpipe namespace torch { namespace distributed { namespace rpc { +using DeviceMap = std::unordered_map; + struct LazyStreamContext; using steady_clock_time_point = @@ -86,8 +86,8 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions { optional> channels, float rpc_timeout, std::string init_method, - std::unordered_map device_maps = {}, - std::vector devices = {}) + std::unordered_map device_maps = {}, + std::vector devices = {}) : RpcBackendOptions(rpc_timeout, init_method), numWorkerThreads(numWorkerThreads), transports(std::move(transports)), @@ -119,15 +119,20 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions { } } - void setDeviceMap( - const std::string& workerName, - const tensorpipe::DeviceMap& deviceMap) { + void setDeviceMap(const std::string& workerName, const DeviceMap& deviceMap) { auto iter = deviceMaps.find(workerName); if (iter == deviceMaps.end()) { deviceMaps[workerName] = deviceMap; } else { for (auto& entry : deviceMap) { - iter->second[entry.first] = entry.second; + // c10::Device has no default constructor, hence map[device] dosn't work + // In C++-17 we can use insert_or_assign. + auto entryIter = iter->second.find(entry.first); + if (entryIter == iter->second.end()) { + iter->second.emplace(entry.first, entry.second); + } else { + entryIter->second = entry.second; + } } } } @@ -135,8 +140,8 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions { int numWorkerThreads; const optional> transports; const optional> channels; - std::unordered_map deviceMaps; - std::vector devices; + std::unordered_map deviceMaps; + std::vector devices; }; // Struct to track the network source metrics @@ -175,8 +180,7 @@ class TensorPipeAgent : public RpcAgent { const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout, - const std::unordered_map& deviceMap = - {}) override; + const DeviceMap& deviceMap = {}) override; // join() and sync() would be deprecated - // https://github.com/pytorch/pytorch/issues/27647 @@ -191,14 +195,13 @@ class TensorPipeAgent : public RpcAgent { const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override; std::vector getWorkerInfos() const override; void setReverseDeviceMaps( - const std::unordered_map& - reverseDeviceMaps); + const std::unordered_map& reverseDeviceMaps); std::unordered_map getMetrics() override; void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override; - tensorpipe::DeviceMap getDeviceMap(const WorkerInfo& dest) const override; + DeviceMap getDeviceMap(const WorkerInfo& dest) const override; using NetworkDataDict = std::unordered_map; @@ -239,10 +242,9 @@ class TensorPipeAgent : public RpcAgent { void pipeWrite( const std::shared_ptr&, Message&& message, - std::vector&& devices, + std::vector&& devices, std::shared_ptr ctx, - std::function, - const tensorpipe::DeviceMap& deviceMap = {}) noexcept; + std::function) noexcept; // Callback of listener accept() void onListenerAccepted( @@ -269,7 +271,7 @@ class TensorPipeAgent : public RpcAgent { uint64_t requestSize, const std::string& destWorkerName); - inline std::vector getDevicesForRemote( + inline std::vector getDevicesForRemote( const std::string& remoteName, const Message& message) const; @@ -280,14 +282,9 @@ class TensorPipeAgent : public RpcAgent { // then, it ends up printing a log message, which may worry the user. To solve // both issues we use a separate atomic flag to know the status of the future. struct AtomicJitFuture { - explicit AtomicJitFuture(const std::vector& devices) { - std::vector fullDevices; - fullDevices.reserve(devices.size()); - for (const c10::DeviceIndex index : devices) { - fullDevices.emplace_back(c10::kCUDA, index); - } + explicit AtomicJitFuture(const std::vector& devices) { jitFuture = std::make_shared( - at::AnyClassType::get(), std::move(fullDevices)); + at::AnyClassType::get(), devices); } std::atomic_flag isComplete = ATOMIC_FLAG_INIT; @@ -310,11 +307,11 @@ class TensorPipeAgent : public RpcAgent { }; const TensorPipeRpcBackendOptions opts_; - std::unordered_map reverseDeviceMaps_; + std::unordered_map reverseDeviceMaps_; // Local devices used by this agent. If application didn't specify this // field, it will be initialized using corresponding local devices in // opts_.deviceMaps and reverseDeviceMaps_; - std::vector devices_; + std::vector devices_; ThreadPool threadPool_; std::shared_ptr context_; diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 3a15a2b1258..ba10d6d3819 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -42,7 +42,7 @@ inline c10::Device indexToDevice(c10::DeviceIndex index) { std::tuple tensorpipeSerialize( Message&& rpcMessage, - std::vector deviceIndices, + std::vector devices, const std::shared_ptr& ctx) { tensorpipe::Message tpMessage; TensorpipeWriteBuffers buffers; @@ -89,10 +89,9 @@ std::tuple tensorpipeSerialize( // tensor to CPU. const auto& tensorData = jit::getWriteableTensorData(tensorDataVec[i], /* toCpu */ false); - tensorpipe::Device targetDevice = - deviceIndices.empty() || deviceIndices[i] == -1 + tensorpipe::Device targetDevice = devices.empty() || devices[i].is_cpu() ? tensorpipe::Device{tensorpipe::kCpuDeviceType, 0} - : tensorpipe::Device{tensorpipe::kCudaDeviceType, deviceIndices[i]}; + : tensorpipe::Device{tensorpipe::kCudaDeviceType, devices[i].index()}; // Enforce memory copy if tensor is created from torch::from_blob, means // that the tensor doesn't own the memory. @@ -126,8 +125,8 @@ std::tuple tensorpipeSerialize( tpMessage.tensors.push_back(std::move(tensor)); #ifdef USE_CUDA_NOT_ROCM } else if (tensorDataVec[i].device().is_cuda()) { - auto stream = at::cuda::CUDAStream( - ctx->getStream(tensorDataVec[i].device().index())); + auto stream = + at::cuda::CUDAStream(ctx->getStream(tensorDataVec[i].device())); tensorpipe::CudaBuffer buffer; buffer.ptr = tensorPtr; buffer.stream = stream.stream(); @@ -209,8 +208,8 @@ std::pair tensorpipeAllocate( tpAllocation.tensors[tensorIdx].buffer = buffer; #ifdef USE_CUDA_NOT_ROCM } else if (tensor.targetDevice->type == tensorpipe::kCudaDeviceType) { - auto deviceIndex = tensor.targetDevice->index; - auto stream = at::cuda::CUDAStream(ctx->getStream(deviceIndex)); + c10::Device device(c10::kCUDA, tensor.targetDevice->index); + auto stream = at::cuda::CUDAStream(ctx->getStream(device)); // CUDACachingAllocator will call recordStream accordingly on the current // stream. at::cuda::CUDAStreamGuard guard(stream); diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.h b/torch/csrc/distributed/rpc/tensorpipe_utils.h index 42f7fe8ea12..346e420957c 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.h +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.h @@ -45,30 +45,13 @@ struct TensorpipeReadBuffers { std::vector tensors; }; -inline std::shared_ptr createLazyStreamContext() { - return createLazyStreamContext( -#ifdef USE_CUDA_NOT_ROCM - c10::DeviceType::CUDA, - [](c10::DeviceIndex index) { - return at::cuda::getStreamFromPool( - /* isHighPriority */ false, /* device */ index); - }, - [](c10::DeviceIndex index) { - return at::cuda::getCurrentCUDAStream(index); - } -#else - c10::DeviceType::CPU, nullptr, nullptr -#endif - ); -} - // Convert an RPC message into a TensorPipe message, plus a holder to all the // data that must be kept alive while the write is performed asynchronously. TORCH_API std::tuple tensorpipeSerialize( Message&& rpcMessage, - std::vector devices = {}, - const std::shared_ptr& = createLazyStreamContext()); + std::vector devices, + const std::shared_ptr& ctx); // Allocate the buffers that will hold the incoming data. They will be managed // by the returned holder, which must be kept alive until the asynchronous read @@ -77,7 +60,7 @@ tensorpipeSerialize( TORCH_API std::pair tensorpipeAllocate( const tensorpipe::Descriptor& tpDescriptor, - const std::shared_ptr& ctx = createLazyStreamContext()); + const std::shared_ptr& ctx); // Convert a TensorPipe message back into an RPC message. This requires the data // to be available and can thus only be performed once the asynchronous read has diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index 909f09fb408..8541f89a604 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -62,7 +62,7 @@ std::shared_ptr FaultyProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds, - const std::unordered_map& deviceMap) { + const std::unordered_map& /* unused */) { // We only fail control messages that have been specified by the test case. // For all other messages, we just send them without any failures. if (!shouldFailMessage(message.type())) { diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h index 94701f0c10d..c7acfd0029f 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h @@ -48,8 +48,8 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout, - const std::unordered_map& deviceMap = - {}) override; + const std::unordered_map& deviceMap = {}) + override; protected: // This function checks the messageTypesToFail_ to determine whether to use diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index c7e68270feb..1f178f0b5cb 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -175,7 +175,7 @@ std::unique_ptr deserializeResponse( // Need to reverse the device map for the backward pass of distributed // autograd. - std::unordered_map reverseDeviceMap; + std::unordered_map reverseDeviceMap; for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); } diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h index f1d7220b244..3dcfd6d6dfa 100644 --- a/torch/csrc/distributed/rpc/utils.h +++ b/torch/csrc/distributed/rpc/utils.h @@ -89,8 +89,6 @@ TORCH_API void populateRemoteProfiledEvents( const std::vector>& eventLists); -using stream_factory_t = std::function; - // A general device context class for both CPU and CUDA. If CUDA is not // available, all CUDA-related methods will be no-ops. struct TORCH_API LazyStreamContext { @@ -99,39 +97,26 @@ struct TORCH_API LazyStreamContext { LazyStreamContext& operator=(const LazyStreamContext& rhs) = delete; LazyStreamContext& operator=(LazyStreamContext&& rhs) & = delete; - LazyStreamContext( - c10::DeviceType device_type, - stream_factory_t stream_creator, - stream_factory_t current_stream_provider) - : device_type_(device_type), - stream_creator_(std::move(stream_creator)), - current_stream_provider_(std::move(current_stream_provider)) {} + explicit LazyStreamContext(c10::DeviceType device_type) + : impl_(device_type) {} // let streams in this context wiat for current streams. void waitForCurrentStreams(const std::vector& tensors = {}) { - if (!stream_creator_) { - // Since the stream_creator is empty the device doesn't support streams - return; - } for (const auto& tensor : tensors) { if (tensor.is_cuda()) { - getStream(tensor.device().index()); + getStream(tensor.device()); } } for (const auto& entry : streams_) { - c10::Event event{device_type_}; - event.record(current_stream_provider_(entry.first)); + c10::Event event{impl_.type()}; + event.record(impl_.getStream(entry.first)); event.block(entry.second); } } // get all streams used in this context std::vector getReservedStreams() const { - if (!stream_creator_) { - // Since the stream_creator is empty the device doesn't support streams - return {}; - } std::vector reservedStreams; reservedStreams.reserve(streams_.size()); for (const auto& entry : streams_) { @@ -142,25 +127,19 @@ struct TORCH_API LazyStreamContext { // get a stream for the given device. If it is the first time using that // device, allocate a new stream and store it in the map. - c10::Stream getStream(c10::DeviceIndex index) { - if (!stream_creator_) { - throw std::runtime_error(c10::str( - "Attempting to access device stream of device ", - index, - ", but the device doesn't support streams")); - } - auto iter = streams_.find(index); + c10::Stream getStream(c10::Device device) { + auto iter = streams_.find(device); if (iter == streams_.end()) { - auto stream = stream_creator_(index); - streams_.emplace(index, stream); + auto stream = impl_.getStreamFromGlobalPool(device); + streams_.emplace(device, stream); return stream; } else { return iter->second; } } - std::set devices() const { - std::set devices; + std::unordered_set devices() const { + std::unordered_set devices; for (const auto& entry : streams_) { devices.insert(entry.first); } @@ -168,26 +147,14 @@ struct TORCH_API LazyStreamContext { } c10::DeviceType deviceType() const { - return device_type_; + return impl_.type(); } private: - std::unordered_map streams_; - c10::DeviceType device_type_; - stream_factory_t stream_creator_; - stream_factory_t current_stream_provider_; + const c10::impl::VirtualGuardImpl impl_; + std::unordered_map streams_; }; -inline std::shared_ptr createLazyStreamContext( - c10::DeviceType device_type, - stream_factory_t stream_creator, - stream_factory_t current_stream_provider) { - return std::make_shared( - device_type, - std::move(stream_creator), - std::move(current_stream_provider)); -} - } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 3486bb3da04..403953e7f7a 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -206,12 +206,12 @@ def _tensorpipe_check_local_device_maps(name, options): f"Invalid device_map configuration for {worker_name}, " f"not 1-to-1 mapping:\ndevice_maps = {device_map}" ) - local_devices.update(options.device_maps[worker_name].keys()) + local_devices.update(key_set) - if len(local_devices) > 0 and not all([ - max(local_devices) < torch.cuda.device_count(), - min(local_devices) >= 0, - ]): + if not all( + (0 <= d.index < torch.cuda.device_count() if d.type == "cuda" else True) + for d in local_devices + ): raise ValueError( f"Invalid device in TensorPipe options on {name}:\n" f"device_maps = {options.device_maps},\n" @@ -235,12 +235,11 @@ def _tensorpipe_check_remote_device_maps(agent, options): remote_device_count = all_device_counts[remote_name] if remote_name in device_maps: device_map = device_maps[remote_name] - key_set = set(device_map.keys()) val_set = set(device_map.values()) - if not all([ - min(val_set) >= 0, - max(val_set) < remote_device_count # check remote range - ]): + if not all( + (0 <= d.index < remote_device_count if d.type == "cuda" else True) + for d in val_set + ): raise ValueError( f"Invalid device_map configuration on {name} " f"for {remote_name}, remote device out of range:\n" diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 2e8f0392dbf..0fb13e35de3 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -9,19 +9,19 @@ from typing import Dict, List, Optional, Union DeviceType = Union[int, str, torch.device] -def _to_device_index(device: DeviceType) -> int: +def _to_device(device: DeviceType) -> torch.device: device = torch.device(device) if device.type != "cuda": raise ValueError( "`set_devices` expect a list of CUDA devices, but got " f"device type {device.type}." ) - return device.index + return device -def _to_device_index_map(device_map: Dict[DeviceType, DeviceType]) -> Dict[int, int]: - device_index_map : Dict[int, int] = {} - reverse_map : Dict[int, int] = {} +def _to_device_map(device_map: Dict[DeviceType, DeviceType]) -> Dict[torch.device, torch.device]: + full_device_map : Dict[torch.device, torch.device] = {} + reverse_map : Dict[torch.device, torch.device] = {} for k in device_map: v = device_map[k] k, v = torch.device(k), torch.device(v) @@ -31,18 +31,18 @@ def _to_device_index_map(device_map: Dict[DeviceType, DeviceType]) -> Dict[int, f"but got device pair {k}: {v}" ) - if v.index in reverse_map: + if v in reverse_map: raise ValueError( "`device_map` only supports 1-to-1 mapping, " - f"trying to map {k} and {reverse_map[v.index]} to {v.index}" + f"trying to map {k} and {reverse_map[v]} to {v}" ) - device_index_map[k.index] = v.index - reverse_map[v.index] = k.index - return device_index_map + full_device_map[k] = v + reverse_map[v] = k + return full_device_map -def _to_device_index_list(devices: List[DeviceType]) -> List[int]: - return list(map(_to_device_index, devices)) +def _to_device_list(devices: List[DeviceType]) -> List[torch.device]: + return list(map(_to_device, devices)) @@ -90,13 +90,13 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): _transports: List = None, _channels: List = None, ): - device_index_maps = ( + full_device_maps = ( {} if device_maps is None else - {k : _to_device_index_map(v) for k, v in device_maps.items()} + {k : _to_device_map(v) for k, v in device_maps.items()} ) - device_index_list = ( + full_device_list = ( [] if devices is None else - _to_device_index_list(devices) + _to_device_list(devices) ) super().__init__( num_worker_threads, @@ -104,8 +104,8 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): _channels, rpc_timeout, init_method, - device_index_maps, - device_index_list, + full_device_maps, + full_device_list, ) def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]): @@ -152,18 +152,18 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): >>> print(rets[0]) # tensor([2., 2.], device='cuda:0') >>> print(rets[1]) # tensor([2., 2.], device='cuda:1') """ - device_index_map = _to_device_index_map(device_map) + full_device_map = _to_device_map(device_map) curr_device_maps = super().device_maps if to in curr_device_maps: - for k, v in device_index_map.items(): + for k, v in full_device_map.items(): if k in curr_device_maps[to] and v != curr_device_maps[to][k]: raise ValueError( "`set_device_map` only supports 1-to-1 mapping, trying" f" to map {k} to {v} and {curr_device_maps[to][k]}" ) - super()._set_device_map(to, device_index_map) + super()._set_device_map(to, full_device_map) def set_devices(self, devices: List[DeviceType]): r""" @@ -175,4 +175,4 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): devices (List of int, str, or torch.device): local devices used by the TensorPipe RPC agent. """ - self.devices = _to_device_index_list(devices) + self.devices = _to_device_list(devices)