Update qengine flag in python to string (#26620)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26620

This change updates torch.backends.quantized.engine to accept a string ("fbgemm"/"qnnpack"/"none" for now).
Internally, _set_qengine takes and _get_qengine returns an int that represents the at::QEngine enum value; the Python property translates to and from the string name.

Test Plan:
python test/test_torch.py

Imported from OSS

Differential Revision: D17533582

fbshipit-source-id: 5103263d0d59ff37d43dec27243cb76ba8ba633f
This commit is contained in:
Supriya Rao 2019-09-23 17:55:18 -07:00 committed by Facebook Github Bot
parent 5d82cefa55
commit 45391ccecb
12 changed files with 47 additions and 175 deletions

View file

@ -105,7 +105,7 @@ void Context::setQEngine(at::QEngine e) {
quantized_engine = e;
return;
}
TORCH_CHECK(false, "quantized engine ", toString(e), "is not supported");
TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
}
std::vector<at::QEngine> Context::supportedQEngines() const {

View file

@ -8,6 +8,8 @@ namespace c10 {
/**
* QEngine is an enum that is used to select the engine to run quantized ops.
* Keep this enum in sync with get_qengine_id() in
* torch/backends/quantized/__init__.py
*/
enum class QEngine : uint8_t {
NoQEngine = 0,

View file

@ -65,8 +65,12 @@ def _calculate_dynamic_qparams(X, dtype):
@contextmanager
def enable_mobile_quantized_engine():
    """Context manager that temporarily switches the quantized backend to qnnpack.

    On exit, restores 'fbgemm' when this build supports it, otherwise 'none'.
    (The diff rendering had left the superseded `torch.qnnpack`/`torch.fbgemm`
    assignments interleaved with their string replacements; only the
    string-based post-change statements are kept.)
    """
    torch.backends.quantized.engine = 'qnnpack'
    try:
        yield
    finally:
        # Restore a sensible default rather than assuming fbgemm is compiled in.
        qengines = torch.backends.quantized.get_supported_qengines()
        if 'fbgemm' in qengines:
            torch.backends.quantized.engine = 'fbgemm'
        else:
            torch.backends.quantized.engine = 'none'

View file

@ -36,6 +36,7 @@ from multiprocessing.reduction import ForkingPickler
from common_device_type import instantiate_device_type_tests, \
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
dtypes, dtypesIfCUDA
import torch.backends.quantized
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -2114,13 +2115,12 @@ class _TestTorchMixin(object):
test_inference(torch.float32)
def test_qengnie(self):
    """Setting torch.backends.quantized.engine to each supported engine
    string must be observable by reading the same property back."""
    # NOTE(review): method name has a typo ("qengnie" vs "qengine"); kept
    # as-is so anything selecting this test by name keeps working.
    # The old commented-out torch._C._set_qengine/_get_qengine lines that
    # the diff rendering left in place are removed.
    qengines = torch.backends.quantized.get_supported_qengines()
    original_qe = torch.backends.quantized.engine
    for qe in qengines:
        torch.backends.quantized.engine = qe
        assert torch.backends.quantized.engine == qe, 'qengine not set successfully'
    # Restore so later tests see the engine they expect.
    torch.backends.quantized.engine = original_qe
def test_new_tensor(self):
expected = torch.autograd.Variable(torch.ByteTensor([1, 1]))

View file

@ -230,7 +230,6 @@ def add_torch_libs():
"torch/csrc/Generator.cpp",
"torch/csrc/Layout.cpp",
"torch/csrc/MemoryFormat.cpp",
"torch/csrc/QEngine.cpp",
"torch/csrc/QScheme.cpp",
"torch/csrc/Module.cpp",
"torch/csrc/PtrWrapper.cpp",
@ -293,7 +292,6 @@ def add_torch_libs():
"torch/csrc/utils/invalid_arguments.cpp",
"torch/csrc/utils/object_ptr.cpp",
"torch/csrc/utils/python_arg_parser.cpp",
"torch/csrc/utils/qengines.cpp",
"torch/csrc/utils/structseq.cpp",
"torch/csrc/utils/tensor_apply.cpp",
"torch/csrc/utils/tensor_dtypes.cpp",

View file

@ -52,7 +52,6 @@ set(TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/Layout.cpp
${TORCH_SRC_DIR}/csrc/MemoryFormat.cpp
${TORCH_SRC_DIR}/csrc/python_dimname.cpp
${TORCH_SRC_DIR}/csrc/QEngine.cpp
${TORCH_SRC_DIR}/csrc/QScheme.cpp
${TORCH_SRC_DIR}/csrc/Module.cpp
${TORCH_SRC_DIR}/csrc/PtrWrapper.cpp
@ -96,7 +95,6 @@ set(TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/utils/invalid_arguments.cpp
${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
${TORCH_SRC_DIR}/csrc/utils/qengines.cpp
${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp

View file

@ -3,16 +3,40 @@ import sys
import torch
import types
# This function should correspond to the enums present in c10/core/QEngine.h
def get_qengine_id(qengine):
    # type: (str) -> int
    """Map a quantized-engine name to its c10::QEngine enum value.

    Keep the mapping in sync with the QEngine enum in c10/core/QEngine.h.

    Raises:
        RuntimeError: if `qengine` is not a recognized engine name.
    """
    # The original set `ret = -1` immediately before raising, which was
    # dead code; a dict lookup makes the valid set explicit.
    engine_ids = {'none': 0, 'fbgemm': 1, 'qnnpack': 2}
    if qengine not in engine_ids:
        raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))
    return engine_ids[qengine]
# This function should correspond to the enums present in c10/core/QEngine.h
def get_qengine_str(qengine):
    # type: (int) -> str
    """Map a c10::QEngine enum value back to its engine name.

    Returns None for values outside the known enum range.
    Keep in sync with the QEngine enum in c10/core/QEngine.h.
    """
    if qengine == 0:
        return 'none'
    if qengine == 1:
        return 'fbgemm'
    if qengine == 2:
        return 'qnnpack'
    return None
def get_supported_qengines():
    """Return the names of the quantized engines compiled into this build."""
    supported = []
    for qe in torch._C._supported_qengines():
        supported.append(get_qengine_str(qe))
    return supported
class ContextProp(object):
    """Descriptor pairing a torch._C getter/setter so the `engine` attribute
    reads and writes engine name strings while the C++ side speaks enum ints.

    (The diff rendering had left both the old pass-through bodies and their
    replacements in __get__/__set__, producing an unreachable return and a
    double setter call; only the translating post-change bodies are kept.)
    """
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        # Translate the QEngine enum int from C++ into its string name.
        return get_qengine_str(self.getter())

    def __set__(self, obj, val):
        # Accept the string name; hand the C++ side its enum int.
        self.setter(get_qengine_id(val))
class QuantizedEngine(types.ModuleType):
def __init__(self, m, name):
@ -21,7 +45,6 @@ class QuantizedEngine(types.ModuleType):
def __getattr__(self, attr):
return self.m.__getattribute__(attr)
# TODO: replace with strings(https://github.com/pytorch/pytorch/pull/26330/files#r324951460)
engine = ContextProp(torch._C._get_qengine, torch._C._set_qengine)
# This is the sys.modules replacement trick, see

View file

@ -27,7 +27,6 @@
#include <torch/csrc/Generator.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/QEngine.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/TypeInfo.h>
#include <torch/csrc/autograd/generated/python_nn_functions.h>
@ -37,7 +36,6 @@
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/qengines.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
#include <torch/csrc/utils/tensor_qschemes.h>
@ -110,7 +108,6 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
torch::utils::initializeLayouts();
torch::utils::initializeMemoryFormats();
torch::utils::initializeQSchemes();
torch::utils::initializeQEngines();
torch::utils::initializeDtypes();
torch::tensors::initialize_python_bindings();
std::string path = THPUtils_unpackString(shm_manager_path);
@ -491,15 +488,16 @@ PyObject *THPModule_getDefaultDevice(PyObject *_unused, PyObject *arg) {
// Binding for torch._C._set_qengine: takes the integer value of an
// at::QEngine enum (see c10/core/QEngine.h) and installs it globally.
// Unsupported engine values are rejected by Context::setQEngine via
// TORCH_CHECK (see the Context hunk above).
// (The diff rendering had left the old THPQEngine-based body interleaved
// with the new int-based one; only the post-change body is kept.)
PyObject *THPModule_setQEngine(PyObject */* unused */, PyObject *arg)
{
  THPUtils_assert(THPUtils_checkLong(arg), "set_qengine expects an int, "
    "but got %s", THPUtils_typename(arg));
  auto qengine = static_cast<int>(THPUtils_unpackLong(arg));
  at::globalContext().setQEngine(static_cast<at::QEngine>(qengine));
  Py_RETURN_NONE;
}
// Binding for torch._C._get_qengine: returns the active engine as the
// integer value of its at::QEngine enum.
// (The diff rendering had left the old THPQEngine_New return alongside the
// new int-packing one; only the post-change body is kept.)
PyObject *THPModule_qEngine(PyObject */* unused */)
{
  return THPUtils_packInt64(static_cast<int>(at::globalContext().qEngine()));
}
PyObject *THPModule_supportedQEngines(PyObject */* unused */)
@ -695,7 +693,6 @@ PyObject* initModule() {
THPDTypeInfo_init(module);
THPLayout_init(module);
THPMemoryFormat_init(module);
THPQEngine_init(module);
THPQScheme_init(module);
THPDevice_init(module);
ASSERT_TRUE(THPVariable_initModule(module));

View file

@ -1,79 +0,0 @@
#include <torch/csrc/QEngine.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_strings.h>
#include <c10/core/QEngine.h>
#include <structmember.h>
#include <cstring>
#include <string>
// Allocates a new torch.qengine Python object wrapping `qengine` and caches
// `name` (truncated to QENGINE_NAME_LEN chars) for use by repr().
// Returns a new reference; throws python_error if allocation fails.
// (This file is deleted by this commit — qengines are now plain ints/strings.)
PyObject* THPQEngine_New(at::QEngine qengine, const std::string& name) {
auto type = (PyTypeObject*)&THPQEngineType;
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
if (!self)
throw python_error();
auto self_ = reinterpret_cast<THPQEngine*>(self.get());
self_->qengine = qengine;
// strncpy does not NUL-terminate when name is >= QENGINE_NAME_LEN chars,
// so the terminator is written explicitly below.
std::strncpy(self_->name, name.c_str(), QENGINE_NAME_LEN);
self_->name[QENGINE_NAME_LEN] = '\0';
return self.release();
}
// repr() slot for torch.qengine objects: "torch.<name>", e.g. "torch.fbgemm".
PyObject* THPQEngine_repr(THPQEngine* self) {
std::string name = self->name;
return THPUtils_packString("torch." + name);
}
// Python type object for torch.qengine. Only tp_repr is customized; every
// other slot uses the CPython default (no tp_dealloc is needed because the
// instance owns no heap data beyond its inline name buffer).
PyTypeObject THPQEngineType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch.qengine", /* tp_name */
sizeof(THPQEngine), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
nullptr, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPQEngine_repr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
};
// Readies THPQEngineType and publishes it on `module` as the attribute
// "qengine". Throws python_error on failure.
void THPQEngine_init(PyObject* module) {
if (PyType_Ready(&THPQEngineType) < 0) {
throw python_error();
}
// PyModule_AddObject steals this reference on success.
Py_INCREF(&THPQEngineType);
if (PyModule_AddObject(module, "qengine", (PyObject*)&THPQEngineType) !=
0) {
throw python_error();
}
}

View file

@ -1,24 +0,0 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <c10/core/QEngine.h>
#include <string>
// Maximum engine-name length cached inline in THPQEngine (excludes the NUL).
constexpr int QENGINE_NAME_LEN = 64;
// Python object wrapping an at::QEngine value plus a cached display name
// used by repr(). (Deleted by this commit.)
struct THPQEngine {
PyObject_HEAD at::QEngine qengine;
char name[QENGINE_NAME_LEN + 1];
};
extern PyTypeObject THPQEngineType;
// True iff obj is exactly a torch.qengine (no subclass check).
inline bool THPQEngine_Check(PyObject* obj) {
return Py_TYPE(obj) == &THPQEngineType;
}
// Allocates a new torch.qengine wrapping `qengine`; returns a new reference.
PyObject* THPQEngine_New(at::QEngine qengine, const std::string& name);
void THPQEngine_init(PyObject* module);

View file

@ -1,37 +0,0 @@
#include <torch/csrc/utils/qengines.h>
#include <c10/core/QEngine.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/QEngine.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
namespace torch {
namespace utils {
// Exposes `qengine` as an attribute of the torch module under `name`
// (e.g. torch.fbgemm). Throws python_error if registration fails.
void addQEngine(
at::QEngine qengine,
const std::string& name,
PyObject* torch_module) {
PyObject* qengine_obj = THPQEngine_New(qengine, name);
// NOTE(review): THPQEngine_New returns a new reference and PyModule_AddObject
// steals one only on success, so the extra Py_INCREF makes these module-level
// singletons effectively immortal — presumably intentional, but the new
// reference leaks on the failure path. Moot: this file is deleted here.
Py_INCREF(qengine_obj);
if (PyModule_AddObject(torch_module, name.c_str(), qengine_obj) != 0) {
throw python_error();
}
}
// Called during extension init: registers torch.no_qengine, torch.fbgemm and
// torch.qnnpack objects on the already-imported torch module.
void initializeQEngines() {
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
if (!torch_module) {
throw python_error();
}
addQEngine(at::kNoQEngine, "no_qengine", torch_module);
addQEngine(at::kFBGEMM, "fbgemm", torch_module);
addQEngine(at::kQNNPACK, "qnnpack", torch_module);
}
} // namespace utils
} // namespace torch

View file

@ -1,10 +0,0 @@
#pragma once
#include <torch/csrc/QEngine.h>
namespace torch {
namespace utils {
void initializeQEngines();
} // namespace utils
} // namespace torch