Optimize torch.ops.ns.opname.overload accessor in torch dispatch (#85132)

This doesn't actually seem to help all that much.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85132
Approved by: https://github.com/wconstab
commit e5fac7f5dc (parent 607eccb13c)
Author: Edward Z. Yang <ezyang@fb.com>
Date:   2022-09-15 21:48:59 -07:00
Committed-by: PyTorch MergeBot

6 changed files with 126 additions and 47 deletions

@@ -398,6 +398,11 @@ public:
     c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
   }
 
+  template <typename F>
+  PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
+    return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
+  }
+
 private:
   explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
       : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}

@@ -6,6 +6,7 @@
 #include <c10/util/either.h>
 #include <c10/util/Optional.h>
 #include <c10/core/DispatchKey.h>
+#include <c10/core/PyHandleCache.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/boxing/KernelFunction.h>
 #include <ATen/core/dispatch/DispatchKeyExtractor.h>
@@ -211,6 +212,11 @@ public:
   // Returns all the operator tags added at the time of registration
   const std::vector<at::Tag>& getTags() const;
 
+  template <typename F>
+  PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
+    return py_cache_.ptr_or(self_interpreter, slow_accessor);
+  }
+
 private:
   OperatorName name_;
@@ -220,6 +226,8 @@ private:
 #endif
   std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
   DispatchKeyExtractor dispatchKeyExtractor_;
+  // Pointer to the torch.ops.ns.op.overload object for speed
+  c10::PyHandleCache py_cache_;
   // kernels_ stores all registered kernels for the corresponding dispatch key
   // and catchAllKernels_ stores the catch-all kernels.

c10/core/PyHandleCache.h (new file, 75 lines)

@@ -0,0 +1,75 @@
+#pragma once
+
+#include <c10/core/impl/PyInterpreter.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/python_stub.h>
+
+#include <atomic>
+
+namespace c10 {
+
+// A PyHandleCache represents a cached pointer from a C++ object to
+// a Python object that represents that object analogously in Python.
+// Upon a cache hit, the relevant object can be retrieved after a test
+// and then a memory load. Two conditions must hold to be able to use this
+// class:
+//
+//  - This must truly be a cache; e.g., the caller must be able to produce
+//    the object some other way if the cache misses.
+//
+//  - This must truly be a handle; e.g., the Python object referenced by
+//    this class must have static lifetime. This means we don't have to
+//    maintain strong ownership or deallocate the object when the C++ object
+//    dies. Static lifetime is a good idea in conjunction with the cache,
+//    since if you are producing a fresh object on miss you won't be
+//    maintaining object identity. If you need bidirectional ownership,
+//    you will want to factor out the pattern in TensorImpl with
+//    resurrection.
+//
+// This cache is expected to not improve perf under torchdeploy, as one
+// interpreter will fill up the cache, and all the interpreters will be
+// unable to use the slot. A potential improvement is to have multiple
+// slots (one per interpreter), which will work in deployment scenarios
+// where there is a stable, fixed number of interpreters. You can also store
+// the relevant state in the Python library, rather than in the non-Python
+// library (although in many cases, this is not convenient, as there may
+// not be a way to conveniently index based on the object.)
+class PyHandleCache {
+ public:
+  PyHandleCache() : pyinterpreter_(nullptr), data_(nullptr) {}
+
+  // Attempt to fetch the pointer from the cache, if the PyInterpreter
+  // matches. If it doesn't exist, or the cache entry is not valid,
+  // use slow_accessor to get the real pointer value and return that
+  // (possibly writing it to the cache, if the cache entry is
+  // available.)
+  template <typename F>
+  PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor)
+      const {
+    // Note [Memory ordering on Python interpreter tag]
+    impl::PyInterpreter* interpreter =
+        pyinterpreter_.load(std::memory_order_acquire);
+    if (C10_LIKELY(interpreter == self_interpreter)) {
+      return data_;
+    } else if (interpreter == nullptr) {
+      auto* r = slow_accessor();
+      impl::PyInterpreter* expected = nullptr;
+      // attempt to claim this cache entry with the specified interpreter tag
+      if (pyinterpreter_.compare_exchange_strong(
+              expected, self_interpreter, std::memory_order_acq_rel)) {
+        data_ = r;
+      }
+      // This shouldn't be possible, as you should be GIL protected
+      TORCH_INTERNAL_ASSERT(expected != self_interpreter);
+      return r;
+    } else {
+      return slow_accessor();
+    }
+  }
+
+ private:
+  mutable std::atomic<impl::PyInterpreter*> pyinterpreter_;
+  mutable PyObject* data_;
+};
+
+} // namespace c10
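
For orientation, a minimal usage sketch of the class above (not part of this commit; make_py_op is a hypothetical slow-path helper assumed to return a Python object with static lifetime):

    #include <c10/core/PyHandleCache.h>
    #include <c10/util/python_stub.h>

    // Hypothetical slow path: builds the Python wrapper on a cache miss.
    // Assumed to return an object with static lifetime (see the class docs).
    PyObject* make_py_op();

    struct CachedOp {
      c10::PyHandleCache py_cache_;

      PyObject* python_op(c10::impl::PyInterpreter* interp) const {
        // Hit: an acquire load and a pointer compare, then return the slot.
        // Miss (or a different interpreter owns the slot): run the slow path.
        return py_cache_.ptr_or(interp, []() -> PyObject* {
          return make_py_op();
        });
      }
    };

On the first call the lambda runs and the slot is claimed for interp; subsequent calls on the same interpreter reduce to the acquire load plus a pointer compare.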

@@ -242,7 +242,7 @@ class TestReductions(TestCase):
         d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
         m = torch.tensor([[True, False, False], [False, True, False]])
         mt = MaskedTensor(d, m)
-        with self.assertRaisesRegex(TypeError, "no implementation found for 'torch.ops.aten.max'"):
+        with self.assertRaisesRegex(TypeError, "no implementation found for 'torch._ops.aten.max.default'"):
             mt.max()
 
     def test_sum(self):

@@ -2195,6 +2195,30 @@ py::object torchDispatchFromTensorImpl(
       TorchFunctionName::TorchDispatch));
 }
 
+py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
+  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
+    // Parse the name into namespace and name (no overload_name)
+    // TODO: put this into the library
+    const auto& schema = op.schema();
+    const auto& qualified_name = op.operator_name().name;
+    const auto& overload_name = schema.overload_name();
+    auto pos = qualified_name.find("::");
+    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
+    // Make me some null terminated strings
+    std::string ns_str = qualified_name.substr(0, pos);
+    const char* ns = ns_str.c_str();
+    const char* func_name = qualified_name.c_str() + pos + strlen("::");
+
+    py::handle torch_api_function =
+        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
+    if (overload_name == "") {
+      return torch_api_function.attr("default").ptr();
+    } else {
+      return torch_api_function.attr(overload_name.c_str()).ptr();
+    }
+  });
+}
+
 void ConcretePyInterpreterVTable::dispatch(
     const c10::OperatorHandle& op,
     torch::jit::Stack* stack) const {
@@ -2202,17 +2226,6 @@ void ConcretePyInterpreterVTable::dispatch(
   const auto num_arguments = schema.arguments().size();
   auto arguments = torch::jit::pop(*stack, num_arguments);
 
-  // Parse the name into namespace and name (no overload_name)
-  // TODO: put this into the library
-  const auto& qualified_name = op.operator_name().name;
-  const auto& overload_name = schema.overload_name();
-  auto pos = qualified_name.find("::");
-  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
-  // Make me some null terminated strings
-  std::string ns_str = qualified_name.substr(0, pos);
-  const char* ns = ns_str.c_str();
-  const char* func_name = qualified_name.c_str() + pos + strlen("::");
-
   // The plan: convert all the arguments back into PyObjects,
   // extracting out the tensor handles, then call
   // handle_torch_function_no_python_arg_parser
@@ -2222,16 +2235,7 @@ void ConcretePyInterpreterVTable::dispatch(
   py::gil_scoped_acquire g;
   std::vector<py::handle> overloaded_args;
-  py::handle torch_api_function =
-      py::module::import("torch").attr("ops").attr(ns).attr(func_name);
-  py::handle torch_api_function_overload;
-  if (overload_name == "") {
-    torch_api_function_overload = torch_api_function.attr("default");
-  } else {
-    torch_api_function_overload =
-        torch_api_function.attr(overload_name.c_str());
-  }
-  std::string module_name_str = "torch.ops." + ns_str;
+  py::handle torch_api_function_overload = getTorchApiFunction(op);
 
   // Find overloaded tensors
   for (const auto idx : c10::irange(arguments.size())) {
@@ -2263,9 +2267,9 @@ void ConcretePyInterpreterVTable::dispatch(
       overloaded_args,
       args.ptr(),
       kwargs.ptr(),
-      func_name,
+      nullptr,
       torch_api_function_overload.ptr(),
-      module_name_str.c_str(),
+      nullptr,
       TorchFunctionName::TorchDispatch);
   pushPyOutToStack(
       op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
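
Taken together, the hunks above leave dispatch() with the shape sketched below (a condensed sketch, not the verbatim result of the diff; the overloaded-args scan is elided, and every name used is one that appears elsewhere in this commit):

    // Pop IValues, rebuild Python args/kwargs, call the __torch_dispatch__
    // handler with the cached overload object, and push the result back.
    void dispatch_sketch(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
      py::gil_scoped_acquire g;
      std::vector<py::handle> overloaded_args;  // filled by scanning arguments
      py::handle fn = getTorchApiFunction(op);  // cache hit after the first call
      auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
      auto obj = handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args_kwargs.first.ptr(),
          args_kwargs.second.ptr(),
          /*func_name=*/nullptr,
          fn.ptr(),
          /*module_name=*/nullptr,
          TorchFunctionName::TorchDispatch);
      pushPyOutToStack(
          op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
    }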
@@ -2279,17 +2283,6 @@ void ConcretePyInterpreterVTable::python_dispatcher(
   const auto num_arguments = schema.arguments().size();
   auto arguments = torch::jit::pop(*stack, num_arguments);
 
-  // Parse the name into namespace and name (no overload_name)
-  // TODO: put this into the library
-  const auto& qualified_name = op.operator_name().name;
-  const auto& overload_name = schema.overload_name();
-  auto pos = qualified_name.find("::");
-  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
-  // Make me some null terminated strings
-  std::string ns_str = qualified_name.substr(0, pos);
-  const char* ns = ns_str.c_str();
-  const char* func_name = qualified_name.c_str() + pos + strlen("::");
-
   // The plan: convert all the arguments back into PyObjects,
   // extracting out the tensor handles, then call
   // handle_torch_function_no_python_arg_parser
@@ -2299,16 +2292,7 @@ void ConcretePyInterpreterVTable::python_dispatcher(
   py::gil_scoped_acquire g;
   std::vector<py::handle> overloaded_args;
-  py::handle torch_api_function =
-      py::module::import("torch").attr("ops").attr(ns).attr(func_name);
-  py::handle torch_api_function_overload;
-  if (overload_name == "") {
-    torch_api_function_overload = torch_api_function.attr("default");
-  } else {
-    torch_api_function_overload =
-        torch_api_function.attr(overload_name.c_str());
-  }
-  std::string module_name_str = "torch.ops." + ns_str;
+  py::handle torch_api_function_overload = getTorchApiFunction(op);
 
   auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
   auto args = std::move(args_kwargs.first);

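For reference, the namespace/name split performed by getTorchApiFunction's slow path, as a standalone sketch (split_qualified_name is a hypothetical helper; only the string handling is shown):

    #include <string>
    #include <utility>

    // "aten::max" -> {"aten", "max"}. With an empty overload_name the Python
    // lookup is then .attr("default"), and the resulting object's __module__
    // and __name__ render as "torch._ops.aten.max.default" in the new error
    // message (see the test change above).
    std::pair<std::string, std::string> split_qualified_name(const std::string& qn) {
      auto pos = qn.find("::");  // asserted to exist in the real code
      return {qn.substr(0, pos), qn.substr(pos + 2)};
    }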
@@ -377,8 +377,15 @@ auto handle_torch_function_no_python_arg_parser(
     // all __torch_function__ implementations in overloaded_args
     // returned NotImplemented, so we raise a TypeError.
     std::stringstream ss;
-    ss << "no implementation found for '" << module_name << "." << func_name
-       << "' on types that implement " << torch_function_name_str << ": [";
+    ss << "no implementation found for '";
+    if (module_name && func_name) {
+      ss << module_name << "." << func_name;
+    } else {
+      py::handle fn = torch_api_function;
+      ss << py::str(fn.attr("__module__")) << "."
+         << py::str(fn.attr("__name__"));
+    }
+    ss << "' on types that implement " << torch_function_name_str << ": [";
     for (auto& arg : overloaded_args) {
       ss << py::repr(get_type_of_overloaded_arg(arg.ptr()));
       if (!arg.is(overloaded_args.back())) {