Add 'share_from_this' to 'torch::jit::Graph' (#87343)

Avoid passing raw pointer of 'torch::jit::Graph' to python. Otherwise, it will corrupt the
`internals::registered_instance` of pybind11, caching a holder for python w.r.t the raw
pointer of 'torch::jit::Graph', while not increasing the use count of the existing shared_ptr.

The behavior afterwards is random and probably undefined.
Most of the time it works, if the holder is deallocated timely on python side, and the
cache then cleared from `internals::registered_instance`. Things are back to normal.
Otherwise, it fails with either segfault or a runtime error of message "Unable to cast
from non-held to held instance". One of such scenarios is normally and correctly
returning a shared_ptr of that 'torch::jit::Graph' to python. Pybind finds the holder via
cache. Due to this, the shared_ptr use_count will not increase. If there is no other use
on C++ side, the graph will be freed, while python still has access, via the holder created
previously.

@t-vi had a great analysis and solution to this exact problem at #51833 which I hope
I had seen before debugging this issue... ~~I'm building the PR based on the original
commit. @t-vi please let me know if you'd prefer otherwise.~~ Sending the PR separately
due to CLA issues.

Need to check in CI if adding `enable_shared_from_this` breaks other stuff.

Fixes #51833, and CI issues in #87258, #86182.

cc @malfet, @kit1980 for changes on JIT IR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87343
Approved by: https://github.com/justinchuby, https://github.com/AllenTiTaiWang, https://github.com/malfet
This commit is contained in:
BowenBao 2022-10-28 23:51:42 +00:00 committed by PyTorch MergeBot
parent ecf277abec
commit 376acf7625
3 changed files with 34 additions and 7 deletions

View file

@ -38,9 +38,7 @@ class TestCustomOps(common_utils.TestCase):
def symbolic_custom_add(g, self, other):
return g.op("Add", self, other)
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic(
torch.onnx.register_custom_op_symbolic(
"custom_namespace::custom_add", symbolic_custom_add, 9
)
@ -48,6 +46,9 @@ class TestCustomOps(common_utils.TestCase):
y = torch.randn(2, 3, 4, requires_grad=False)
model = CustomAddModel()
# before fixing #51833 this used to give a PyBind error
# with PyTorch 1.10dev ("Unable to cast from non-held to held
# instance (T& to Holder<T>)")
onnxir, _ = do_export(model, (x, y), opset_version=11)
onnx_model = onnx.ModelProto.FromString(onnxir)
prepared = c2.prepare(onnx_model)

View file

@ -239,6 +239,11 @@ struct Value {
const Node* node() const {
return node_;
}
/**
* @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
* Check #87343 for details.
*/
Graph* owningGraph();
const Graph* owningGraph() const;
// TODO: make this more const correct
@ -398,6 +403,10 @@ struct TORCH_API Node {
}
SourceRange sourceRange() const;
/**
* @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
* Check #87343 for details.
*/
Graph* owningGraph() {
return graph_;
}
@ -1049,6 +1058,10 @@ struct Block {
const Node* param_node() const {
return input_;
}
/**
* @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
* Check #87343 for details.
*/
Graph* owningGraph() {
return graph_;
}
@ -1163,7 +1176,7 @@ struct Block {
std::shared_ptr<Wrap<Block>> wrap_;
};
struct Graph {
struct Graph : std::enable_shared_from_this<Graph> {
AT_DISALLOW_COPY_AND_ASSIGN(Graph);
friend struct Node;
friend struct Value;

View file

@ -426,8 +426,16 @@ void NodeToONNX(
WithInsertPoint insert_point_guard(new_block);
WithCurrentScope scope_guard(*g, n->scope());
// IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
// Python. Check #87343 for details.
py::object raw_output = onnx.attr("_run_symbolic_function")(
g, new_block, n, py_inputs, env, operator_export_type);
g->shared_from_this(),
new_block,
n,
py_inputs,
env,
operator_export_type);
// Find new nodes that have been created by _run_symbolic_function and
// propagate metadata
@ -530,8 +538,11 @@ void NodeToONNX(
opset_version,
pyobj.attr("symbolic"),
/* custom */ true);
// IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
// Python. Check #87343 for details.
py::object raw_output = onnx.attr("_run_symbolic_method")(
new_block->owningGraph(),
new_block->owningGraph()->shared_from_this(),
op->name(),
pyobj.attr("symbolic"),
py_symbolic_args);
@ -542,8 +553,10 @@ void NodeToONNX(
Node* n = static_cast<Node*>(op);
n->s_(attr::name, op->name());
// Call symbolic function
// IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
// Python. Check #87343 for details.
py::object raw_output = onnx.attr("_run_symbolic_function")(
new_block->owningGraph(),
new_block->owningGraph()->shared_from_this(),
new_block,
n,
py_symbolic_args,