pytorch/torch/csrc/distributed/c10d/python_comm_hook.h
Yi Wang b03b45afd9 [DDP Comm Hook] Use a single tensor instead of a tensor list as the comm hook result (#62074)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62074

Since SPMD mode is retired, the comm hook result will always be a single tensor.

This improves the comm hook developer experience, as there is no longer any need to add an extra `[0]` to the result of the precursor future.

#Closes: https://github.com/pytorch/pytorch/issues/61914
ghstack-source-id: 134164593

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork

Reviewed By: rohan-varma

Differential Revision: D29864732

fbshipit-source-id: 59fe6dd78b66214b1788514ad4d236039d9bda31
2021-07-23 03:32:05 -07:00
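A minimal sketch of the convention change described in the commit message. It uses hypothetical stand-in types (`FakeTensor`, `HookResult`) and a hypothetical `parseLegacyResult` in place of the real `at::Tensor`, `c10::IValue`, and PyTorch parsing code. It only illustrates the shape of the change, not PyTorch's actual `parseHookResult`.

```cpp
#include <cassert>
#include <stdexcept>
#include <variant>
#include <vector>

// Hypothetical stand-in for at::Tensor: a flat buffer of gradient values.
struct FakeTensor {
  std::vector<float> data;
};

// Hypothetical stand-in for the value a comm hook's future completes with.
// Before the change: a list of tensors. After: a single tensor.
using HookResult = std::variant<FakeTensor, std::vector<FakeTensor>>;

// Old convention: callers had to pick element [0] out of the list.
FakeTensor parseLegacyResult(const HookResult& result) {
  const auto& list = std::get<std::vector<FakeTensor>>(result);
  assert(!list.empty() && "expected at least one tensor in the list");
  return list[0];
}

// New convention: the result already is the bucket's single tensor.
FakeTensor parseHookResult(const HookResult& result) {
  if (const auto* tensor = std::get_if<FakeTensor>(&result)) {
    return *tensor;
  }
  throw std::runtime_error("expected a single tensor as the hook result");
}
```

Dropping the list wrapper removes both the `[0]` indexing and the empty-list case that the legacy parser had to guard against.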

#pragma once
#include <c10d/comm.hpp>
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
namespace c10d {
class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
 public:
  // Takes a state and a callable hook. The inputs are Python objects.
  // The state is passed to the hook in the runHook method, and it can be used
  // to maintain and update any state information during the execution of the
  // hook. The hook performs user-specified processing and returns a future
  // representing the asynchronous communication of gradients.
  PythonCommHook(py::object state, py::object hook)
      : state_(std::move(state)), hook_(std::move(hook)) {}

  ~PythonCommHook() override;

  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;

  at::Tensor parseHookResult(const c10::IValue& result) override;

 private:
  // Only needed for stateful communication.
  py::object state_;
  py::object hook_;
};
} // namespace c10d
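The constructor comment above describes a simple pattern: the hook object owns an optional state and a callable, `runHook` passes the state to the callable, and the callable returns a future whose completed value is later parsed into the bucket's tensor. The following is a minimal sketch of that pattern in plain C++. It assumes `std::shared_future` in place of `c10::ivalue::Future` and uses hypothetical `Bucket`, `HookState`, `CommHookSketch`, and `averageHook` names. It is not how `PythonCommHook` itself is implemented; the real class stores Python objects and calls into Python.

```cpp
#include <cassert>
#include <functional>
#include <future>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for c10d::GradBucket: one flattened gradient buffer.
struct Bucket {
  std::vector<float> buffer;
};

// Hypothetical user state; here it just counts how often the hook ran.
struct HookState {
  int calls = 0;
};

using BufferFuture = std::shared_future<std::vector<float>>;

// Loosely mirrors PythonCommHook: owns an optional state plus a callable,
// forwards the state to the callable in runHook, and parses the result.
class CommHookSketch {
 public:
  using Hook = std::function<BufferFuture(HookState*, Bucket&)>;

  CommHookSketch(std::shared_ptr<HookState> state, Hook hook)
      : state_(std::move(state)), hook_(std::move(hook)) {
    assert(hook_ && "hook must be callable");
  }

  // The state (possibly null for stateless hooks) is passed on every call.
  BufferFuture runHook(Bucket& bucket) { return hook_(state_.get(), bucket); }

  // Called with the future's completed value: one buffer, not a list.
  std::vector<float> parseHookResult(const std::vector<float>& result) const {
    return result;
  }

 private:
  std::shared_ptr<HookState> state_;  // only needed for stateful hooks
  Hook hook_;
};

// Example hook: divides by a hypothetical world size of 2 and returns an
// already-completed future. A real hook would launch an async allreduce.
BufferFuture averageHook(HookState* state, Bucket& bucket) {
  if (state != nullptr) {
    ++state->calls;
  }
  std::vector<float> out = bucket.buffer;
  for (float& v : out) {
    v /= 2.0f;
  }
  std::promise<std::vector<float>> promise;
  promise.set_value(std::move(out));
  return promise.get_future().share();
}
```

Keeping the state in the hook object rather than in the callable lets the same function serve both stateful and stateless registrations, which matches the header's note that `state_` is only needed for stateful communication.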