mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60663 Test Plan: Imported from OSS Reviewed By: soulitzer Differential Revision: D29466223 fbshipit-source-id: 65dc3a935c18a0e6b93a37e24543c696e6ae0321
25 lines
643 B
C++
25 lines
643 B
C++
#pragma once
|
|
|
|
#include <pybind11/pybind11.h>
|
|
#include <torch/csrc/autograd/python_variable.h>
|
|
#include <torch/csrc/autograd/saved_variable_hooks.h>
|
|
#include <torch/csrc/python_headers.h>
|
|
#include <torch/csrc/THP_export.h>
|
|
#include <ATen/ATen.h>
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
struct PySavedVariableHooks : public SavedVariableHooks {
|
|
PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook);
|
|
void call_pack_hook(at::Tensor &tensor) override;
|
|
at::Tensor call_unpack_hook() override;
|
|
~PySavedVariableHooks() override;
|
|
|
|
private:
|
|
PyObject* pack_hook_;
|
|
PyObject* unpack_hook_;
|
|
};
|
|
|
|
}}
|