mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Speed up fx graph iteration by implementing it in C++ (#128288)
Before this change:
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s)
```
After this change:
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s)
```
5.7x improvement

Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
Approved by: https://github.com/jansel, https://github.com/albanD
This commit is contained in:
parent
fa88f390a0
commit
5b5d269d34
7 changed files with 300 additions and 18 deletions
|
|
@ -827,6 +827,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/dynamo/guards.cpp",
|
||||
"torch/csrc/dynamo/init.cpp",
|
||||
"torch/csrc/functorch/init.cpp",
|
||||
"torch/csrc/fx/node.cpp",
|
||||
"torch/csrc/mps/Module.cpp",
|
||||
"torch/csrc/mtia/Module.cpp",
|
||||
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
||||
|
|
|
|||
|
|
@ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ...
|
|||
# Defined in torch/csrc/jit/runtime/static/init.cpp
|
||||
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
|
||||
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
|
||||
|
||||
# Defined in torch/csrc/fx/node.cpp
class _NodeBase:
    # C-level fields of a node in the doubly linked FX node list.
    # `_erased` marks a node the iterator must skip; `_prev` / `_next`
    # are the neighbouring nodes (a freshly initialised node links to itself).
    _erased: _bool
    _prev: "_NodeBase"
    _next: "_NodeBase"

# Parameterise the base class so `for n in _NodeIter(...)` is typed as
# `_NodeBase` rather than the implicit `Any` of a bare `Iterator`.
class _NodeIter(Iterator[_NodeBase]):
    # `root` is the sentinel node of the list; `reversed` selects whether the
    # walk follows `_prev` (True) or `_next` (False).
    def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
    def __iter__(self) -> Iterator[_NodeBase]: ...
    def __next__(self) -> _NodeBase: ...
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@
|
|||
#include <torch/csrc/cpu/Module.h>
|
||||
#include <torch/csrc/dynamo/init.h>
|
||||
#include <torch/csrc/functorch/init.h>
|
||||
#include <torch/csrc/fx/node.h>
|
||||
#include <torch/csrc/inductor/aoti_runner/pybind.h>
|
||||
#include <torch/csrc/jit/python/init.h>
|
||||
#include <torch/csrc/jit/python/python_ir.h>
|
||||
|
|
@ -1602,6 +1603,8 @@ PyObject* initModule() {
|
|||
THPDevice_init(module);
|
||||
THPStream_init(module);
|
||||
THPEvent_init(module);
|
||||
NodeBase_init(module);
|
||||
NodeIter_init(module);
|
||||
ASSERT_TRUE(THPVariable_initModule(module));
|
||||
ASSERT_TRUE(THPFunction_initModule(module));
|
||||
ASSERT_TRUE(THPEngine_initModule(module));
|
||||
|
|
|
|||
257
torch/csrc/fx/node.cpp
Normal file
257
torch/csrc/fx/node.cpp
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
|
||||
////////////////////////////////
|
||||
// NodeBase
|
||||
///////////////////////////////
|
||||
|
||||
struct NodeBase {
|
||||
PyObject_HEAD bool _erased;
|
||||
NodeBase* _prev;
|
||||
NodeBase* _next;
|
||||
};
|
||||
|
||||
// tp_new slot for _NodeBase: allocates the instance through the type's
// allocator. `args` / `kwds` are ignored here; field setup happens in tp_init.
// Returns nullptr when allocation fails (tp_alloc's own failure result is
// passed straight through).
static PyObject* NodeBase_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  return type->tp_alloc(type, 0);
}
|
||||
|
||||
// tp_init slot for _NodeBase: resets the node to the "detached" state, i.e.
// not erased and linked to itself in both directions.
//
// `args` / `kwds` are accepted and ignored. Always returns 0.
//
// __init__ may legally be invoked more than once on the same object, so the
// links that were stored before this call are released after the new ones are
// in place; otherwise every repeated call would leak two references.
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
  NodeBase* old_prev = self->_prev;
  NodeBase* old_next = self->_next;

  self->_erased = false;
  // Each link owns its own reference to `self`.
  Py_INCREF(self);
  self->_prev = self;
  Py_INCREF(self);
  self->_next = self;

  // Drop the previous links only after the fields point at valid objects, so
  // any code triggered by the decref never observes dangling pointers.
  Py_XDECREF(old_prev);
  Py_XDECREF(old_next);
  return 0;
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static struct PyMemberDef NodeBase_members[] = {
|
||||
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
|
||||
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
|
||||
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
|
||||
{nullptr} /* Sentinel */
|
||||
};
|
||||
|
||||
// tp_traverse slot for _NodeBase: reports both list links to the cyclic
// garbage collector. Py_VISIT returns early with the visitor's result if it
// is non-zero; otherwise 0 is returned.
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
  Py_VISIT(self->_prev);
  Py_VISIT(self->_next);
  return 0;
}
|
||||
|
||||
// tp_clear slot for _NodeBase: releases both link references and leaves the
// fields null. Also invoked from NodeBase_dealloc. Always returns 0.
static int NodeBase_clear(NodeBase* self) {
  Py_CLEAR(self->_prev);
  Py_CLEAR(self->_next);
  return 0;
}
|
||||
|
||||
// tp_dealloc slot for _NodeBase: stops GC tracking first, then releases the
// owned links, and finally returns the memory through the type's tp_free.
static void NodeBase_dealloc(PyObject* self) {
  PyObject_GC_UnTrack(self);
  NodeBase* node = (NodeBase*)self;
  (void)NodeBase_clear(node);
  Py_TYPE(self)->tp_free(self);
}
|
||||
|
||||
// Static type object for torch._C._NodeBase.
// The type is subclassable (Py_TPFLAGS_BASETYPE) and participates in cyclic
// garbage collection (Py_TPFLAGS_HAVE_GC with tp_traverse / tp_clear set).
// Slots are listed positionally, in PyTypeObject declaration order.
static PyTypeObject NodeBaseType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch._C._NodeBase", /* tp_name */
    sizeof(NodeBase), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)NodeBase_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)NodeBase_traverse, /* tp_traverse */
    (inquiry)NodeBase_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    NodeBase_members, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    (initproc)NodeBase_init_fn, /* tp_init */
    nullptr, /* tp_alloc */
    NodeBase_new, /* tp_new */
};
|
||||
|
||||
// Registers the _NodeBase type on `module`.
// Returns true on success and false when PyModule_AddType reports an error
// (a negative result).
bool NodeBase_init(PyObject* module) {
  const int status = PyModule_AddType(module, &NodeBaseType);
  return status >= 0;
}
|
||||
|
||||
////////////////////////////////
|
||||
// NodeIter
|
||||
////////////////////////////////
|
||||
|
||||
struct NodeIter {
|
||||
PyObject_HEAD bool _reversed;
|
||||
NodeBase* _root;
|
||||
NodeBase* _cur;
|
||||
};
|
||||
|
||||
// tp_new slot for _NodeIter: allocates the instance through the type's
// allocator. `args` / `kwds` are ignored here; they are parsed in tp_init.
// Returns nullptr when allocation fails (tp_alloc's own failure result is
// passed straight through).
static PyObject* NodeIter_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  return type->tp_alloc(type, 0);
}
|
||||
|
||||
// tp_init slot for _NodeIter: `_NodeIter(root, reversed)`.
//
// Parameters (positional or keyword):
//   root     - sentinel node of the list; must be a _NodeBase instance
//              (subclasses included).
//   reversed - truthy to walk `_prev` links, falsy to walk `_next` links.
//
// Returns 0 on success, -1 with a Python exception set when the arguments do
// not parse.
//
// Changes relative to the previous "Ob|" format:
//   * "O!" type-checks `root` against NodeBaseType. The iterator later reads
//     `_prev` / `_next` / `_erased` straight out of the struct, so accepting
//     an arbitrary object would read unrelated memory.
//   * "p" stores the truth value into an `int`. "b" writes an unsigned char
//     through the supplied pointer, which is not a valid way to populate a
//     C++ `bool`, and it rejects non-integer truthy values.
//   * The trailing "|" (no optional arguments followed it) is dropped.
//   * References held from an earlier __init__ call are released instead of
//     being overwritten and leaked.
static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
  PyObject* root = nullptr;
  int reversed = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* keywords[] = {"root", "reversed", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "O!p",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(keywords),
          &NodeBaseType,
          &root,
          &reversed)) {
    return -1;
  }

  NodeBase* old_root = self->_root;
  NodeBase* old_cur = self->_cur;

  self->_reversed = (reversed != 0);
  // `_root` and `_cur` each own a separate reference to the sentinel.
  Py_INCREF(root);
  self->_root = (NodeBase*)root;
  Py_INCREF(root);
  self->_cur = (NodeBase*)root;

  // Release whatever a previous __init__ call stored (null on first call).
  Py_XDECREF(old_root);
  Py_XDECREF(old_cur);
  return 0;
}
|
||||
|
||||
template <bool reversed>
|
||||
PyObject* NodeIter_iternext_helper(NodeIter* self) {
|
||||
// It should be possible to relax the ref counting here
|
||||
// but in practice, we do not have that many _erased Nodes,
|
||||
// so probably not worth it.
|
||||
if constexpr (reversed) {
|
||||
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = prev;
|
||||
} else {
|
||||
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = next;
|
||||
}
|
||||
while (self->_cur != self->_root) {
|
||||
if (!self->_cur->_erased) {
|
||||
Py_INCREF(self->_cur);
|
||||
return (PyObject*)self->_cur;
|
||||
}
|
||||
if constexpr (reversed) {
|
||||
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = prev;
|
||||
} else {
|
||||
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = next;
|
||||
}
|
||||
}
|
||||
PyErr_SetNone(PyExc_StopIteration);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// tp_iternext slot for _NodeIter: dispatches on the runtime `_reversed` flag
// to the matching compile-time specialisation of the walk.
PyObject* NodeIter_iternext(PyObject* _self) {
  NodeIter* iter = (NodeIter*)_self;
  return iter->_reversed ? NodeIter_iternext_helper<true>(iter)
                         : NodeIter_iternext_helper<false>(iter);
}
|
||||
|
||||
// tp_traverse slot for _NodeIter: reports the sentinel and the current node
// to the cyclic garbage collector. Py_VISIT returns early with the visitor's
// result if it is non-zero; otherwise 0 is returned.
static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
  Py_VISIT(self->_root);
  Py_VISIT(self->_cur);
  return 0;
}
|
||||
|
||||
// tp_clear slot for _NodeIter: releases the references to the sentinel and to
// the current node, leaving both fields null. Also invoked from
// NodeIter_dealloc. Always returns 0.
static int NodeIter_clear(NodeIter* self) {
  Py_CLEAR(self->_root);
  Py_CLEAR(self->_cur);
  return 0;
}
|
||||
|
||||
// tp_dealloc slot for _NodeIter: stops GC tracking first, then releases the
// owned node references, and finally returns the memory through tp_free.
static void NodeIter_dealloc(PyObject* self) {
  PyObject_GC_UnTrack(self);
  NodeIter* iter = (NodeIter*)self;
  (void)NodeIter_clear(iter);
  Py_TYPE(self)->tp_free(self);
}
|
||||
|
||||
// Static type object for torch._C._NodeIter.
// The type is its own iterator (tp_iter = PyObject_SelfIter), yields nodes
// through NodeIter_iternext, and participates in cyclic garbage collection.
// Unlike _NodeBase it does not set Py_TPFLAGS_BASETYPE.
// Slots are listed positionally, in PyTypeObject declaration order.
static PyTypeObject NodeIterType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch._C._NodeIter", /* tp_name */
    sizeof(NodeIter), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)NodeIter_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)NodeIter_traverse, /* tp_traverse */
    (inquiry)NodeIter_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    PyObject_SelfIter, /* tp_iter */
    NodeIter_iternext, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    (initproc)NodeIter_init_fn, /* tp_init */
    nullptr, /* tp_alloc */
    NodeIter_new, /* tp_new */
};
|
||||
|
||||
// Registers the _NodeIter type on `module`.
// Returns true on success and false when PyModule_AddType reports an error
// (a negative result).
bool NodeIter_init(PyObject* module) {
  const int status = PyModule_AddType(module, &NodeIterType);
  return status >= 0;
}
|
||||
6
torch/csrc/fx/node.h
Normal file
6
torch/csrc/fx/node.h
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
bool NodeBase_init(PyObject* module);
|
||||
bool NodeIter_init(PyObject* module);
|
||||
|
|
@ -4,6 +4,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_na
|
|||
import torch.utils._pytree as pytree
|
||||
from . import _pytree as fx_pytree
|
||||
from ._compatibility import compatibility
|
||||
from torch._C import _NodeIter
|
||||
|
||||
import os
|
||||
import contextlib
|
||||
|
|
@ -271,20 +272,8 @@ class _node_list:
|
|||
return self.graph._len
|
||||
|
||||
def __iter__(self):
|
||||
root = self.graph._root
|
||||
if self.direction == "_next":
|
||||
cur = root._next
|
||||
while cur is not root:
|
||||
if not cur._erased:
|
||||
yield cur
|
||||
cur = cur._next
|
||||
else:
|
||||
assert self.direction == "_prev"
|
||||
cur = root._prev
|
||||
while cur is not root:
|
||||
if not cur._erased:
|
||||
yield cur
|
||||
cur = cur._prev
|
||||
assert self.direction == "_prev" or self.direction == "_next"
|
||||
yield from _NodeIter(self.graph._root, self.direction == "_prev")
|
||||
|
||||
def __reversed__(self):
|
||||
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import inspect
|
|||
import warnings
|
||||
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
|
||||
from .._ops import ops as _ops
|
||||
from torch._C import _NodeBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Graph
|
||||
|
|
@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str:
|
|||
return str(arg)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class Node:
|
||||
class Node(_NodeBase):
|
||||
"""
|
||||
``Node`` is the data structure that represents individual operations within
|
||||
a ``Graph``. For the most part, Nodes represent callsites to various entities,
|
||||
|
|
@ -197,6 +198,7 @@ class Node:
|
|||
annotation of values in the generated code or for other types
|
||||
of analyses.
|
||||
"""
|
||||
super().__init__()
|
||||
self.graph = graph
|
||||
self.name = name # unique name of value being created
|
||||
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
|
||||
|
|
@ -235,9 +237,6 @@ class Node:
|
|||
# does not produce a value, it's more of a notation. Thus, this value
|
||||
# describes the type of args[0] in the ``return`` node.
|
||||
self.type : Optional[Any] = return_type
|
||||
self._prev = self
|
||||
self._next = self
|
||||
self._erased = False
|
||||
self._sort_key: Any = ()
|
||||
|
||||
# If set, use this fn to print this node
|
||||
|
|
@ -247,6 +246,22 @@ class Node:
|
|||
# transformations. This metadata is preserved across node copies
|
||||
self.meta : Dict[str, Any] = {}
|
||||
|
||||
def __getstate__(self):
    """Return the picklable state of this node.

    The state is a shallow copy of the instance ``__dict__`` with the
    ``_erased``, ``_prev`` and ``_next`` attributes added explicitly under
    their own names.
    """
    return {
        **self.__dict__,
        "_erased": self._erased,
        "_prev": self._prev,
        "_next": self._next,
    }
|
||||
|
||||
def __setstate__(self, state):
    """Restore this node from a state produced by ``__getstate__``.

    ``_erased``, ``_prev`` and ``_next`` are taken out of the state and
    assigned as attributes; every other entry is merged into the instance
    ``__dict__``.

    Args:
        state: mapping containing at least the keys ``"_erased"``,
            ``"_prev"`` and ``"_next"``.

    Raises:
        KeyError: if one of the three required keys is missing.
    """
    # Work on a shallow copy so the caller's mapping is not mutated by the
    # pops below.
    state = dict(state)
    _erased = state.pop("_erased")
    _prev = state.pop("_prev")
    _next = state.pop("_next")
    self.__dict__.update(state)
    self._erased = _erased
    self._prev = _prev
    self._next = _next
|
||||
|
||||
@property
|
||||
def next(self) -> 'Node':
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue