Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-15 21:00:47 +00:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62074
Since SPMD mode is retired, the comm hook result will always be a single tensor. This improves the comm hook developer experience, as there is no longer any need to add an extra `[0]` to the precursor future's result.
Closes: https://github.com/pytorch/pytorch/issues/61914
ghstack-source-id: 134164593
Test Plan: `buck test mode/dev-nosan caffe2/test/distributed:c10d` and `buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork`
Reviewed By: rohan-varma
Differential Revision: D29864732
fbshipit-source-id: 59fe6dd78b66214b1788514ad4d236039d9bda31
34 lines · 1 KiB · C++
#pragma once
|
|
|
|
#include <c10d/comm.hpp>
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>

#include <utility>
|
|
|
|
namespace c10d {
|
|
|
|
class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
|
|
public:
|
|
// Takes a state and a callable hook. The inputs are Python objects.
|
|
// The state is passed to the hook in runHook method, and it can be used to
|
|
// maintain and update any state information during the execution of the hook.
|
|
// The hook performs user-specified processing and returns a future indicating
|
|
// asychronous communication of gradients.
|
|
PythonCommHook(py::object state, py::object hook)
|
|
: state_(std::move(state)), hook_(std::move(hook)) {}
|
|
|
|
~PythonCommHook() override;
|
|
|
|
c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
|
|
|
|
at::Tensor parseHookResult(const c10::IValue& result) override;
|
|
|
|
private:
|
|
// Only needed for stateful communication.
|
|
py::object state_;
|
|
py::object hook_;
|
|
};
|
|
|
|
} // namespace c10d
|