pytorch/torch/csrc/autograd/python_saved_variable_hooks.h
Victor Quach ee5a97de11 Register Saved Tensors hooks (#60663)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60663

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D29466223

fbshipit-source-id: 65dc3a935c18a0e6b93a37e24543c696e6ae0321
2021-07-15 08:09:55 -07:00

25 lines
643 B
C++

#pragma once
#include <pybind11/pybind11.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/THP_export.h>
#include <ATen/ATen.h>
namespace py = pybind11;
namespace torch { namespace autograd {

// SavedVariableHooks implementation whose pack/unpack hooks are Python
// callables. Only the declaration lives in this header; every member function
// is defined out of line.
//
// NOTE(review): the object keeps raw PyObject* handles and declares its own
// destructor, so it presumably holds references to the two hooks for its
// whole lifetime -- confirm against the out-of-line definitions.
//
// Copying and moving are disabled on purpose: an implicit member-wise copy
// would duplicate the raw PyObject* handles without any reference-count
// bookkeeping, leaving two objects that both believe they are responsible for
// the same handles.
struct PySavedVariableHooks : public SavedVariableHooks {
  // Creates the hooks object from the Python `pack_hook` and `unpack_hook`
  // callables.
  PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook);

  // Non-copyable and non-movable (see the note on the struct).
  PySavedVariableHooks(const PySavedVariableHooks&) = delete;
  PySavedVariableHooks& operator=(const PySavedVariableHooks&) = delete;
  PySavedVariableHooks(PySavedVariableHooks&&) = delete;
  PySavedVariableHooks& operator=(PySavedVariableHooks&&) = delete;

  // Overrides SavedVariableHooks::call_pack_hook; `tensor` is passed by
  // non-const reference.
  void call_pack_hook(at::Tensor &tensor) override;

  // Overrides SavedVariableHooks::call_unpack_hook and returns a tensor by
  // value.
  at::Tensor call_unpack_hook() override;

  ~PySavedVariableHooks() override;

private:
  // Raw handles to the Python hooks. Default-initialized to nullptr so the
  // members never hold an indeterminate value.
  PyObject* pack_hook_ = nullptr;
  PyObject* unpack_hook_ = nullptr;
};

}} // namespace torch::autograd