mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Optimize torch.ops.ns.opname.overload accessor in torch dispatch (#85132)
This doesn't actually seem to help all that much. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/85132 Approved by: https://github.com/wconstab
This commit is contained in:
parent
607eccb13c
commit
e5fac7f5dc
6 changed files with 126 additions and 47 deletions
|
|
@ -398,6 +398,11 @@ public:
|
|||
c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
|
||||
}
|
||||
|
||||
// Fetch the cached torch.ops.<ns>.<name>.<overload> Python object for this
// operator, computing it with `slow_accessor` on a cache miss.
// Forwards to OperatorEntry::getPythonOp (which consults c10::PyHandleCache).
// `slow_accessor` is only invoked when no value is cached for
// `self_interpreter`; the returned PyObject* is a borrowed handle
// (presumably static-lifetime, per PyHandleCache's contract — see that class).
template <typename F>
PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
  return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
}
|
||||
|
||||
private:
|
||||
// Private: handles are only minted by the Dispatcher (a friend), from an
// iterator into its operator registry. Both the raw pointer and the
// iterator are retained; the pointer gives direct access, the iterator
// allows deregistration later.
explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <c10/util/either.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/core/DispatchKey.h>
|
||||
#include <c10/core/PyHandleCache.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/boxing/KernelFunction.h>
|
||||
#include <ATen/core/dispatch/DispatchKeyExtractor.h>
|
||||
|
|
@ -211,6 +212,11 @@ public:
|
|||
// Returns all the operator tags added at the time of registration
|
||||
const std::vector<at::Tag>& getTags() const;
|
||||
|
||||
// Return the cached Python object (torch.ops.ns.op.overload) associated
// with this operator entry, or compute it via `slow_accessor` on a miss.
// Delegates the caching protocol to c10::PyHandleCache::ptr_or.
template <typename F>
PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
  return py_cache_.ptr_or(self_interpreter, slow_accessor);
}
|
||||
|
||||
private:
|
||||
|
||||
OperatorName name_;
|
||||
|
|
@ -220,6 +226,8 @@ private:
|
|||
#endif
|
||||
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
|
||||
DispatchKeyExtractor dispatchKeyExtractor_;
|
||||
// Pointer to the torch.ops.ns.op.overload object for speed
|
||||
c10::PyHandleCache py_cache_;
|
||||
|
||||
// kernels_ stores all registered kernels for the corresponding dispatch key
|
||||
// and catchAllKernels_ stores the catch-all kernels.
|
||||
|
|
|
|||
75
c10/core/PyHandleCache.h
Normal file
75
c10/core/PyHandleCache.h
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Macros.h>
#include <c10/util/python_stub.h>

#include <atomic>

namespace c10 {

// A PyHandleCache represents a cached pointer from a C++ object to
// a Python object that represents that object analogously in Python.
// Upon a cache hit, the relevant object can be retrieved after a test
// and then a memory load. Two conditions must hold to be able to use this
// class:
//
// - This must truly be a cache; e.g., the caller must be able to produce
//   the object some other way if the cache hit misses.
//
// - This must truly be a handle; e.g., the Python object referenced by
//   this class must have static lifetime. This means we don't have to
//   maintain strong ownership or deallocate the object when the C++ object
//   dies. Static lifetime is a good idea in conjunction with the cache,
//   since if you are producing a fresh object on miss you won't be
//   maintaining object identity. If you need bidirectional ownership,
//   you will want to factor out the pattern in TensorImpl with
//   resurrection.
//
// This cache is expected to not improve perf under torchdeploy, as one
// interpreter will fill up the cache, and all the interpreters will be
// unable to use the slot. A potential improvement is to have multiple
// slots (one per interpreter), which will work in deployment scenarios
// where there a stable, fixed number of interpreters. You can also store
// the relevant state in the Python library, rather than in the non-Python
// library (although in many cases, this is not convenient, as there may
// not be a way to conveniently index based on the object.)
class PyHandleCache {
public:
  // Starts empty: no interpreter has claimed the slot, no cached object.
  PyHandleCache() : pyinterpreter_(nullptr), data_(nullptr) {}

  // Attempt to fetch the pointer from the cache, if the PyInterpreter
  // matches. If it doesn't exist, or the cache entry is not valid,
  // use slow_accessor to get the real pointer value and return that
  // (possibly writing it to the cache, if the cache entry is
  // available.)
  //
  // NOTE(review): data_ is a plain (non-atomic) pointer that is stored
  // *after* the interpreter tag is published by the CAS below. This is
  // presumably safe because all calls for a given interpreter are
  // serialized by that interpreter's GIL (the assert below relies on the
  // same assumption), so no same-interpreter reader can race the data_
  // write — confirm against the callers before relaxing that invariant.
  template <typename F>
  PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor)
      const {
    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyinterpreter_.load(std::memory_order_acquire);
    if (C10_LIKELY(interpreter == self_interpreter)) {
      // Fast path: this interpreter already owns the slot; data_ was
      // filled in by the same (GIL-serialized) interpreter.
      return data_;
    } else if (interpreter == nullptr) {
      // Slot is unclaimed: compute the value, then try to claim the slot.
      auto* r = slow_accessor();
      impl::PyInterpreter* expected = nullptr;
      // attempt to claim this cache entry with the specified interpreter tag
      if (pyinterpreter_.compare_exchange_strong(
              expected, self_interpreter, std::memory_order_acq_rel)) {
        data_ = r;
      }
      // This shouldn't be possible, as you should be GIL protected
      // (i.e., a concurrent claimant must be a *different* interpreter).
      TORCH_INTERNAL_ASSERT(expected != self_interpreter);
      return r;
    } else {
      // Another interpreter owns the slot; never evict it — always take
      // the slow path for this interpreter.
      return slow_accessor();
    }
  }

private:
  // Tag identifying which interpreter's object is cached; doubles as the
  // claim flag (nullptr == unclaimed). mutable: ptr_or is logically const.
  mutable std::atomic<impl::PyInterpreter*> pyinterpreter_;
  // Borrowed, static-lifetime Python object; valid only for pyinterpreter_.
  mutable PyObject* data_;
};

} // namespace c10
|
||||
|
|
@ -242,7 +242,7 @@ class TestReductions(TestCase):
|
|||
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
|
||||
m = torch.tensor([[True, False, False], [False, True, False]])
|
||||
mt = MaskedTensor(d, m)
|
||||
with self.assertRaisesRegex(TypeError, "no implementation found for 'torch.ops.aten.max'"):
|
||||
with self.assertRaisesRegex(TypeError, "no implementation found for 'torch._ops.aten.max.default'"):
|
||||
mt.max()
|
||||
|
||||
def test_sum(self):
|
||||
|
|
|
|||
|
|
@ -2195,6 +2195,30 @@ py::object torchDispatchFromTensorImpl(
|
|||
TorchFunctionName::TorchDispatch));
|
||||
}
|
||||
|
||||
// Resolve the torch.ops.<ns>.<name>.<overload> Python object corresponding
// to `op`. The result is cached on the operator entry via getPythonOp; the
// lambda below is the slow path, executed only on a cache miss. Returns a
// borrowed handle (the torch.ops objects are long-lived module attributes).
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Split the qualified name "ns::name" into its two halves; the
    // overload name travels separately on the schema.
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    auto sep = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(sep != std::string::npos, qualified_name);
    // Carve out null-terminated strings for the attr() lookups below.
    std::string namespace_str = qualified_name.substr(0, sep);
    const char* ns_cstr = namespace_str.c_str();
    const char* name_cstr = qualified_name.c_str() + sep + strlen("::");

    py::handle packet =
        py::module::import("torch").attr("ops").attr(ns_cstr).attr(name_cstr);
    // An empty overload name maps to the "default" overload attribute.
    const char* overload_attr =
        overload_name.empty() ? "default" : overload_name.c_str();
    return packet.attr(overload_attr).ptr();
  });
}
|
||||
|
||||
void ConcretePyInterpreterVTable::dispatch(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack) const {
|
||||
|
|
@ -2202,17 +2226,6 @@ void ConcretePyInterpreterVTable::dispatch(
|
|||
const auto num_arguments = schema.arguments().size();
|
||||
auto arguments = torch::jit::pop(*stack, num_arguments);
|
||||
|
||||
// Parse the name into namespace and name (no overload_name)
|
||||
// TODO: put this into the library
|
||||
const auto& qualified_name = op.operator_name().name;
|
||||
const auto& overload_name = schema.overload_name();
|
||||
auto pos = qualified_name.find("::");
|
||||
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
||||
// Make me some null terminated strings
|
||||
std::string ns_str = qualified_name.substr(0, pos);
|
||||
const char* ns = ns_str.c_str();
|
||||
const char* func_name = qualified_name.c_str() + pos + strlen("::");
|
||||
|
||||
// The plan: convert all the arguments back into PyObjects,
|
||||
// extracting out the tensor handles, then call
|
||||
// handle_torch_function_no_python_arg_parser
|
||||
|
|
@ -2222,16 +2235,7 @@ void ConcretePyInterpreterVTable::dispatch(
|
|||
py::gil_scoped_acquire g;
|
||||
|
||||
std::vector<py::handle> overloaded_args;
|
||||
py::handle torch_api_function =
|
||||
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
py::handle torch_api_function_overload;
|
||||
if (overload_name == "") {
|
||||
torch_api_function_overload = torch_api_function.attr("default");
|
||||
} else {
|
||||
torch_api_function_overload =
|
||||
torch_api_function.attr(overload_name.c_str());
|
||||
}
|
||||
std::string module_name_str = "torch.ops." + ns_str;
|
||||
py::handle torch_api_function_overload = getTorchApiFunction(op);
|
||||
|
||||
// Find overloaded tensors
|
||||
for (const auto idx : c10::irange(arguments.size())) {
|
||||
|
|
@ -2263,9 +2267,9 @@ void ConcretePyInterpreterVTable::dispatch(
|
|||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
func_name,
|
||||
nullptr,
|
||||
torch_api_function_overload.ptr(),
|
||||
module_name_str.c_str(),
|
||||
nullptr,
|
||||
TorchFunctionName::TorchDispatch);
|
||||
pushPyOutToStack(
|
||||
op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
|
||||
|
|
@ -2279,17 +2283,6 @@ void ConcretePyInterpreterVTable::python_dispatcher(
|
|||
const auto num_arguments = schema.arguments().size();
|
||||
auto arguments = torch::jit::pop(*stack, num_arguments);
|
||||
|
||||
// Parse the name into namespace and name (no overload_name)
|
||||
// TODO: put this into the library
|
||||
const auto& qualified_name = op.operator_name().name;
|
||||
const auto& overload_name = schema.overload_name();
|
||||
auto pos = qualified_name.find("::");
|
||||
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
||||
// Make me some null terminated strings
|
||||
std::string ns_str = qualified_name.substr(0, pos);
|
||||
const char* ns = ns_str.c_str();
|
||||
const char* func_name = qualified_name.c_str() + pos + strlen("::");
|
||||
|
||||
// The plan: convert all the arguments back into PyObjects,
|
||||
// extracting out the tensor handles, then call
|
||||
// handle_torch_function_no_python_arg_parser
|
||||
|
|
@ -2299,16 +2292,7 @@ void ConcretePyInterpreterVTable::python_dispatcher(
|
|||
py::gil_scoped_acquire g;
|
||||
|
||||
std::vector<py::handle> overloaded_args;
|
||||
py::handle torch_api_function =
|
||||
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
py::handle torch_api_function_overload;
|
||||
if (overload_name == "") {
|
||||
torch_api_function_overload = torch_api_function.attr("default");
|
||||
} else {
|
||||
torch_api_function_overload =
|
||||
torch_api_function.attr(overload_name.c_str());
|
||||
}
|
||||
std::string module_name_str = "torch.ops." + ns_str;
|
||||
py::handle torch_api_function_overload = getTorchApiFunction(op);
|
||||
|
||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||
auto args = std::move(args_kwargs.first);
|
||||
|
|
|
|||
|
|
@ -377,8 +377,15 @@ auto handle_torch_function_no_python_arg_parser(
|
|||
// all __torch_function__ implementations in overloaded_args
|
||||
// returned NotImplemented, so we raise a TypeError.
|
||||
std::stringstream ss;
|
||||
ss << "no implementation found for '" << module_name << "." << func_name
|
||||
<< "' on types that implement " << torch_function_name_str << ": [";
|
||||
ss << "no implementation found for '";
|
||||
if (module_name && func_name) {
|
||||
ss << module_name << "." << func_name;
|
||||
} else {
|
||||
py::handle fn = torch_api_function;
|
||||
ss << py::str(fn.attr("__module__")) << "."
|
||||
<< py::str(fn.attr("__name__"));
|
||||
}
|
||||
ss << "' on types that implement " << torch_function_name_str << ": [";
|
||||
for (auto& arg : overloaded_args) {
|
||||
ss << py::repr(get_type_of_overloaded_arg(arg.ptr()));
|
||||
if (!arg.is(overloaded_args.back())) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue