mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update qengine flag in python to string (#26620)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26620 This change updates torch.backend.quantized.engine to accept string ("fbgemm"/"qnnpack"/"none" for now). set_qengine and get_qengine return an int which represents the at::QEngine enum Test Plan: python test/test_torch.py Imported from OSS Differential Revision: D17533582 fbshipit-source-id: 5103263d0d59ff37d43dec27243cb76ba8ba633f
This commit is contained in:
parent
5d82cefa55
commit
45391ccecb
12 changed files with 47 additions and 175 deletions
|
|
@ -105,7 +105,7 @@ void Context::setQEngine(at::QEngine e) {
|
|||
quantized_engine = e;
|
||||
return;
|
||||
}
|
||||
TORCH_CHECK(false, "quantized engine ", toString(e), "is not supported");
|
||||
TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
|
||||
}
|
||||
|
||||
std::vector<at::QEngine> Context::supportedQEngines() const {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ namespace c10 {
|
|||
|
||||
/**
|
||||
* QEngine is an enum that is used to select the engine to run quantized ops.
|
||||
* Keep this enum in sync with get_qengine_id() in
|
||||
* torch/backends/quantized/__init__.py
|
||||
*/
|
||||
enum class QEngine : uint8_t {
|
||||
NoQEngine = 0,
|
||||
|
|
|
|||
|
|
@ -65,8 +65,12 @@ def _calculate_dynamic_qparams(X, dtype):
|
|||
|
||||
@contextmanager
|
||||
def enable_mobile_quantized_engine():
|
||||
torch.backends.quantized.engine = torch.qnnpack
|
||||
torch.backends.quantized.engine = 'qnnpack'
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.backends.quantized.engine = torch.fbgemm
|
||||
qengines = torch.backends.quantized.get_supported_qengines()
|
||||
if 'fbgemm' in qengines:
|
||||
torch.backends.quantized.engine = 'fbgemm'
|
||||
else:
|
||||
torch.backends.quantized.engine = 'none'
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from multiprocessing.reduction import ForkingPickler
|
|||
from common_device_type import instantiate_device_type_tests, \
|
||||
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
|
||||
dtypes, dtypesIfCUDA
|
||||
import torch.backends.quantized
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
@ -2114,13 +2115,12 @@ class _TestTorchMixin(object):
|
|||
test_inference(torch.float32)
|
||||
|
||||
def test_qengnie(self):
|
||||
qengines = torch._C._supported_qengines()
|
||||
# [TODO] Enable after the interface change
|
||||
# original_qe = torch._C._get_qengine()
|
||||
# for qe in qengines:
|
||||
# torch._C._set_qengine(qe)
|
||||
# assert torch._C._get_qengine() == qe, 'qengine not set successfully'
|
||||
# torch._C._set_qengine(original_qe)
|
||||
qengines = torch.backends.quantized.get_supported_qengines()
|
||||
original_qe = torch.backends.quantized.engine
|
||||
for qe in qengines:
|
||||
torch.backends.quantized.engine = qe
|
||||
assert torch.backends.quantized.engine == qe, 'qengine not set successfully'
|
||||
torch.backends.quantized.engine = original_qe
|
||||
|
||||
def test_new_tensor(self):
|
||||
expected = torch.autograd.Variable(torch.ByteTensor([1, 1]))
|
||||
|
|
|
|||
|
|
@ -230,7 +230,6 @@ def add_torch_libs():
|
|||
"torch/csrc/Generator.cpp",
|
||||
"torch/csrc/Layout.cpp",
|
||||
"torch/csrc/MemoryFormat.cpp",
|
||||
"torch/csrc/QEngine.cpp",
|
||||
"torch/csrc/QScheme.cpp",
|
||||
"torch/csrc/Module.cpp",
|
||||
"torch/csrc/PtrWrapper.cpp",
|
||||
|
|
@ -293,7 +292,6 @@ def add_torch_libs():
|
|||
"torch/csrc/utils/invalid_arguments.cpp",
|
||||
"torch/csrc/utils/object_ptr.cpp",
|
||||
"torch/csrc/utils/python_arg_parser.cpp",
|
||||
"torch/csrc/utils/qengines.cpp",
|
||||
"torch/csrc/utils/structseq.cpp",
|
||||
"torch/csrc/utils/tensor_apply.cpp",
|
||||
"torch/csrc/utils/tensor_dtypes.cpp",
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ set(TORCH_PYTHON_SRCS
|
|||
${TORCH_SRC_DIR}/csrc/Layout.cpp
|
||||
${TORCH_SRC_DIR}/csrc/MemoryFormat.cpp
|
||||
${TORCH_SRC_DIR}/csrc/python_dimname.cpp
|
||||
${TORCH_SRC_DIR}/csrc/QEngine.cpp
|
||||
${TORCH_SRC_DIR}/csrc/QScheme.cpp
|
||||
${TORCH_SRC_DIR}/csrc/Module.cpp
|
||||
${TORCH_SRC_DIR}/csrc/PtrWrapper.cpp
|
||||
|
|
@ -96,7 +95,6 @@ set(TORCH_PYTHON_SRCS
|
|||
${TORCH_SRC_DIR}/csrc/utils/invalid_arguments.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/qengines.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
|
||||
|
|
|
|||
|
|
@ -3,16 +3,40 @@ import sys
|
|||
import torch
|
||||
import types
|
||||
|
||||
# This function should correspond to the enums present in c10/core/QEngine.h
|
||||
def get_qengine_id(qengine):
|
||||
# type: (str) -> int
|
||||
if qengine == 'none':
|
||||
ret = 0
|
||||
elif qengine == 'fbgemm':
|
||||
ret = 1
|
||||
elif qengine == 'qnnpack':
|
||||
ret = 2
|
||||
else:
|
||||
ret = -1
|
||||
raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))
|
||||
return ret
|
||||
|
||||
# This function should correspond to the enums present in c10/core/QEngine.h
|
||||
def get_qengine_str(qengine):
|
||||
# type: (int) -> str
|
||||
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack'}
|
||||
return all_engines.get(qengine)
|
||||
|
||||
def get_supported_qengines():
|
||||
qengines = torch._C._supported_qengines()
|
||||
return [get_qengine_str(qe) for qe in qengines]
|
||||
|
||||
class ContextProp(object):
|
||||
def __init__(self, getter, setter):
|
||||
self.getter = getter
|
||||
self.setter = setter
|
||||
|
||||
def __get__(self, obj, objtype):
|
||||
return self.getter()
|
||||
return get_qengine_str(self.getter())
|
||||
|
||||
def __set__(self, obj, val):
|
||||
self.setter(val)
|
||||
self.setter(get_qengine_id(val))
|
||||
|
||||
class QuantizedEngine(types.ModuleType):
|
||||
def __init__(self, m, name):
|
||||
|
|
@ -21,7 +45,6 @@ class QuantizedEngine(types.ModuleType):
|
|||
|
||||
def __getattr__(self, attr):
|
||||
return self.m.__getattribute__(attr)
|
||||
# TODO: replace with strings(https://github.com/pytorch/pytorch/pull/26330/files#r324951460)
|
||||
engine = ContextProp(torch._C._get_qengine, torch._C._set_qengine)
|
||||
|
||||
# This is the sys.modules replacement trick, see
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@
|
|||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/Layout.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
#include <torch/csrc/QEngine.h>
|
||||
#include <torch/csrc/QScheme.h>
|
||||
#include <torch/csrc/TypeInfo.h>
|
||||
#include <torch/csrc/autograd/generated/python_nn_functions.h>
|
||||
|
|
@ -37,7 +36,6 @@
|
|||
#include <torch/csrc/tensor/python_tensor.h>
|
||||
#include <torch/csrc/utils/tensor_dtypes.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/utils/qengines.h>
|
||||
#include <torch/csrc/utils/tensor_layouts.h>
|
||||
#include <torch/csrc/utils/tensor_memoryformats.h>
|
||||
#include <torch/csrc/utils/tensor_qschemes.h>
|
||||
|
|
@ -110,7 +108,6 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
|
|||
torch::utils::initializeLayouts();
|
||||
torch::utils::initializeMemoryFormats();
|
||||
torch::utils::initializeQSchemes();
|
||||
torch::utils::initializeQEngines();
|
||||
torch::utils::initializeDtypes();
|
||||
torch::tensors::initialize_python_bindings();
|
||||
std::string path = THPUtils_unpackString(shm_manager_path);
|
||||
|
|
@ -491,15 +488,16 @@ PyObject *THPModule_getDefaultDevice(PyObject *_unused, PyObject *arg) {
|
|||
|
||||
PyObject *THPModule_setQEngine(PyObject */* unused */, PyObject *arg)
|
||||
{
|
||||
TORCH_CHECK(THPQEngine_Check(arg), "qengine arg must be an instance of the torch.qengine");
|
||||
const auto qengine = reinterpret_cast<THPQEngine*>(arg);
|
||||
at::globalContext().setQEngine(qengine->qengine);
|
||||
THPUtils_assert(THPUtils_checkLong(arg), "set_qengine expects an int, "
|
||||
"but got %s", THPUtils_typename(arg));
|
||||
auto qengine = static_cast<int>(THPUtils_unpackLong(arg));
|
||||
at::globalContext().setQEngine(static_cast<at::QEngine>(qengine));
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject *THPModule_qEngine(PyObject */* unused */)
|
||||
{
|
||||
return THPQEngine_New(at::globalContext().qEngine(), toString(at::globalContext().qEngine()));
|
||||
return THPUtils_packInt64(static_cast<int>(at::globalContext().qEngine()));
|
||||
}
|
||||
|
||||
PyObject *THPModule_supportedQEngines(PyObject */* unused */)
|
||||
|
|
@ -695,7 +693,6 @@ PyObject* initModule() {
|
|||
THPDTypeInfo_init(module);
|
||||
THPLayout_init(module);
|
||||
THPMemoryFormat_init(module);
|
||||
THPQEngine_init(module);
|
||||
THPQScheme_init(module);
|
||||
THPDevice_init(module);
|
||||
ASSERT_TRUE(THPVariable_initModule(module));
|
||||
|
|
|
|||
|
|
@ -1,79 +0,0 @@
|
|||
#include <torch/csrc/QEngine.h>
|
||||
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
|
||||
#include <c10/core/QEngine.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
PyObject* THPQEngine_New(at::QEngine qengine, const std::string& name) {
|
||||
auto type = (PyTypeObject*)&THPQEngineType;
|
||||
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
|
||||
if (!self)
|
||||
throw python_error();
|
||||
auto self_ = reinterpret_cast<THPQEngine*>(self.get());
|
||||
self_->qengine = qengine;
|
||||
std::strncpy(self_->name, name.c_str(), QENGINE_NAME_LEN);
|
||||
self_->name[QENGINE_NAME_LEN] = '\0';
|
||||
return self.release();
|
||||
}
|
||||
|
||||
PyObject* THPQEngine_repr(THPQEngine* self) {
|
||||
std::string name = self->name;
|
||||
return THPUtils_packString("torch." + name);
|
||||
}
|
||||
|
||||
PyTypeObject THPQEngineType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0) "torch.qengine", /* tp_name */
|
||||
sizeof(THPQEngine), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
nullptr, /* tp_dealloc */
|
||||
nullptr, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
(reprfunc)THPQEngine_repr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
nullptr, /* tp_methods */
|
||||
nullptr, /* tp_members */
|
||||
nullptr, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
nullptr, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
nullptr, /* tp_init */
|
||||
nullptr, /* tp_alloc */
|
||||
nullptr, /* tp_new */
|
||||
};
|
||||
|
||||
void THPQEngine_init(PyObject* module) {
|
||||
if (PyType_Ready(&THPQEngineType) < 0) {
|
||||
throw python_error();
|
||||
}
|
||||
Py_INCREF(&THPQEngineType);
|
||||
if (PyModule_AddObject(module, "qengine", (PyObject*)&THPQEngineType) !=
|
||||
0) {
|
||||
throw python_error();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <c10/core/QEngine.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
constexpr int QENGINE_NAME_LEN = 64;
|
||||
|
||||
struct THPQEngine {
|
||||
PyObject_HEAD at::QEngine qengine;
|
||||
char name[QENGINE_NAME_LEN + 1];
|
||||
};
|
||||
|
||||
extern PyTypeObject THPQEngineType;
|
||||
|
||||
inline bool THPQEngine_Check(PyObject* obj) {
|
||||
return Py_TYPE(obj) == &THPQEngineType;
|
||||
}
|
||||
|
||||
PyObject* THPQEngine_New(at::QEngine qengine, const std::string& name);
|
||||
|
||||
void THPQEngine_init(PyObject* module);
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
#include <torch/csrc/utils/qengines.h>
|
||||
|
||||
#include <c10/core/QEngine.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/QEngine.h>
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
|
||||
void addQEngine(
|
||||
at::QEngine qengine,
|
||||
const std::string& name,
|
||||
PyObject* torch_module) {
|
||||
PyObject* qengine_obj = THPQEngine_New(qengine, name);
|
||||
Py_INCREF(qengine_obj);
|
||||
if (PyModule_AddObject(torch_module, name.c_str(), qengine_obj) != 0) {
|
||||
throw python_error();
|
||||
}
|
||||
}
|
||||
|
||||
void initializeQEngines() {
|
||||
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
|
||||
if (!torch_module) {
|
||||
throw python_error();
|
||||
}
|
||||
|
||||
addQEngine(at::kNoQEngine, "no_qengine", torch_module);
|
||||
addQEngine(at::kFBGEMM, "fbgemm", torch_module);
|
||||
addQEngine(at::kQNNPACK, "qnnpack", torch_module);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/QEngine.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
|
||||
void initializeQEngines();
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
Loading…
Reference in a new issue