#include #include #include namespace torch { namespace distributed { namespace rpc { ScriptRemoteCall::ScriptRemoteCall( std::shared_ptr op, std::vector&& stack, const RRefId& retRRefId, const ForkId& retForkId) : ScriptCall(std::move(op), std::move(stack)), retRRefId_(retRRefId), retForkId_(retForkId) {} ScriptRemoteCall::ScriptRemoteCall( const c10::QualifiedName& qualifiedName, std::vector&& stack, const RRefId& retRRefId, const ForkId& retForkId, const bool isAsyncExecution) : ScriptCall(qualifiedName, std::move(stack), isAsyncExecution), retRRefId_(retRRefId), retForkId_(retForkId) {} std::unique_ptr ScriptRemoteCall::fromIValues( std::vector& ivalues) { // remove the last element from values and convert it back to an RRef auto retForkId = RRefId::fromIValue(ivalues.back()); ivalues.pop_back(); auto retRRefId = ForkId::fromIValue(ivalues.back()); ivalues.pop_back(); auto scriptCallPtr = ScriptCall::fromIValues(ivalues); if (scriptCallPtr->hasOp()) { return std::make_unique( scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId); } else { return std::make_unique( scriptCallPtr->qualifiedName(), std::move(ivalues), retRRefId, retForkId, scriptCallPtr->isAsyncExecution()); } } c10::intrusive_ptr ScriptRemoteCall::toMessageImpl() && { std::vector ivalues; ScriptCall::toIValues(ivalues); ivalues.emplace_back(retRRefId_.toIValue()); ivalues.emplace_back(retForkId_.toIValue()); std::vector tensor_table; auto payload = jit::pickle( c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table); return c10::make_intrusive( std::move(payload), std::move(tensor_table), MessageType::SCRIPT_REMOTE_CALL); } std::unique_ptr ScriptRemoteCall::fromMessage( const Message& message) { auto payload = static_cast(message.payload().data()); auto payload_size = message.payload().size(); auto value = jit::unpickle( payload, payload_size, *RpcAgent::getCurrentRpcAgent()->getTypeResolver(), message.tensors()); auto values = value.toTupleRef().elements().vec(); TORCH_CHECK(!values.empty(), "Malformed message: empty values unpickled"); return fromIValues(values); } } // namespace rpc } // namespace distributed } // namespace torch