mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Remove fbgemm_is_cpu_supported in favor of torch.backends.quantized.supported_qengines (#26840)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26840

Cleaning up the top-level namespace. Also makes cosmetic changes to torch.backends.quantized.

Test Plan: Imported from OSS

Differential Revision: D17604403

Pulled By: dzhulgakov

fbshipit-source-id: c55af277ea7319d962a82a6120f65ccd47a60abc
This commit is contained in:
parent
e4fba752cb
commit
764bf826e3
19 changed files with 127 additions and 139 deletions
|
|
@ -4,26 +4,30 @@
|
|||
|
||||
#include <c10/core/TensorOptions.h>
|
||||
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/cpu/FlushDenormal.h>
|
||||
|
||||
#include <TH/TH.h> // for USE_LAPACK
|
||||
#include <TH/TH.h> // for USE_LAPACK
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
#include "fbgemm/Fbgemm.h"
|
||||
#endif // USE_FBGEMM
|
||||
|
||||
namespace at {
|
||||
|
||||
Context::Context()
|
||||
: thc_state(nullptr, [](THCState* p){ /* no-op */ } )
|
||||
, thh_state(nullptr, [](THHState* p){ /* no-op */ } ) {}
|
||||
: thc_state(nullptr, [](THCState* p) { /* no-op */ }),
|
||||
thh_state(nullptr, [](THHState* p) { /* no-op */ }) {}
|
||||
|
||||
// TODO: This could be bad juju if someone calls globalContext() in the
|
||||
// destructor of an object with static lifetime.
|
||||
Context & globalContext() {
|
||||
Context& globalContext() {
|
||||
static Context globalContext_;
|
||||
return globalContext_;
|
||||
}
|
||||
|
|
@ -96,7 +100,8 @@ bool Context::hasLAPACK() const {
|
|||
}
|
||||
|
||||
at::QEngine Context::qEngine() const {
|
||||
return quantized_engine;
|
||||
// If the engine wasn't explicitly set, fall back to the last (highest-priority) entry of supportedQEngines()
|
||||
return quantized_engine.value_or(supportedQEngines().back());
|
||||
}
|
||||
|
||||
void Context::setQEngine(at::QEngine e) {
|
||||
|
|
@ -108,16 +113,31 @@ void Context::setQEngine(at::QEngine e) {
|
|||
TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
|
||||
}
|
||||
|
||||
std::vector<at::QEngine> Context::supportedQEngines() const {
|
||||
static auto supported_qengines = {
|
||||
at::kNoQEngine,
|
||||
#ifdef USE_FBGEMM
|
||||
at::kFBGEMM,
|
||||
#endif
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
at::kQNNPACK,
|
||||
#endif
|
||||
};
|
||||
const std::vector<at::QEngine>& Context::supportedQEngines() const {
|
||||
static auto supported_qengines = []() {
|
||||
std::vector<at::QEngine> engines = {};
|
||||
// Engines are listed in priority order: later one wins
|
||||
// By default we prefer FBGEMM if we're running on server side
|
||||
// QNNPACK on server side has some issues, so it stays supported but is listed before kNoQEngine and is therefore never picked as the default there.
|
||||
#ifdef C10_MOBILE
|
||||
engines.push_back(at::kNoQEngine);
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
engines.push_back(at::kQNNPACK);
|
||||
#endif
|
||||
#else // C10_MOBILE
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
engines.push_back(at::kQNNPACK);
|
||||
#endif
|
||||
engines.push_back(at::kNoQEngine);
|
||||
#endif // C10_MOBILE
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (fbgemm::fbgemmSupportedCPU()) {
|
||||
engines.push_back(at::kFBGEMM);
|
||||
}
|
||||
#endif
|
||||
return engines;
|
||||
}();
|
||||
return supported_qengines;
|
||||
}
|
||||
|
||||
|
|
@ -143,4 +163,4 @@ struct LegacyDeviceTypeInit : public LegacyDeviceTypeInitInterface {
|
|||
};
|
||||
REGISTER_LEGACY_TYPE_INIT(LegacyDeviceTypeInit);
|
||||
|
||||
}
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class CAFFE2_API Context {
|
|||
void setDeterministicCuDNN(bool);
|
||||
at::QEngine qEngine() const;
|
||||
void setQEngine(at::QEngine e);
|
||||
std::vector<at::QEngine> supportedQEngines() const;
|
||||
const std::vector<at::QEngine>& supportedQEngines() const;
|
||||
|
||||
private:
|
||||
void initCUDAIfNeeded(DeviceType p) {
|
||||
|
|
@ -127,12 +127,7 @@ class CAFFE2_API Context {
|
|||
bool deterministic_cudnn = false;
|
||||
bool benchmark_cudnn = false;
|
||||
bool enabled_mkldnn = true;
|
||||
at::QEngine quantized_engine =
|
||||
#ifdef USE_FBGEMM
|
||||
at::kFBGEMM;
|
||||
#else
|
||||
at::kNoQEngine;
|
||||
#endif
|
||||
c10::optional<at::QEngine> quantized_engine = c10::nullopt;
|
||||
std::unique_ptr<THCState, void(*)(THCState*)> thc_state;
|
||||
std::unique_ptr<THHState, void(*)(THHState*)> thh_state;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -219,7 +219,6 @@ bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) {
|
|||
{"aten::fbgemm_linear_fp16_weight", ""},
|
||||
{"aten::fbgemm_pack_quantized_matrix", ""},
|
||||
{"aten::fbgemm_pack_quantized_matrix", "KN"},
|
||||
{"aten::fbgemm_is_cpu_supported", ""},
|
||||
{"aten::log", ""},
|
||||
{"aten::log_", ""},
|
||||
{"aten::log10", ""},
|
||||
|
|
|
|||
|
|
@ -252,10 +252,6 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
|
|||
quantized, col_offsets, q_params.scale, q_params.zero_point);
|
||||
}
|
||||
|
||||
bool fbgemm_is_cpu_supported() {
|
||||
return fbgemm::fbgemmSupportedCPU();
|
||||
}
|
||||
|
||||
Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
|
|
|
|||
|
|
@ -1506,9 +1506,6 @@
|
|||
- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: fbgemm_is_cpu_supported() -> bool
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
|
||||
- func: linspace.out(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
|
|
|||
|
|
@ -185,11 +185,12 @@ Tensor qnnpack_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
|
|||
public:
|
||||
Tensor operator()(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
|
||||
check_inputs(qa, qb);
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
|
||||
qa.scalar_type() == kQUInt8 && qb.scalar_type() == kQUInt8) {
|
||||
return qnnpack_add(qa, qb, scale, zero_point);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
auto qc = at::_empty_affine_quantized(
|
||||
qa.sizes(),
|
||||
at::device(kCPU).dtype(qa.scalar_type()),
|
||||
|
|
|
|||
|
|
@ -401,7 +401,7 @@ class QMaxPool2D_arr_args final : public torch::OperatorKernel {
|
|||
std::vector<int64_t> dilation,
|
||||
bool ceil_mode) {
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
|
||||
if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) {
|
||||
return qnnpack_maxpool(qx, kernel_size, stride, padding, dilation, ceil_mode);
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ Tensor qnnpack_relu(Tensor input) {
|
|||
|
||||
Tensor quantized_relu(const Tensor& qx) {
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
|
||||
if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) {
|
||||
return qnnpack_relu(qx);
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from torch._C import parse_schema
|
|||
white_list = [
|
||||
('quantize', datetime.date(2019, 10, 1)),
|
||||
('q_per_channel_axis', datetime.date(2019, 10, 1)),
|
||||
('fbgemm_is_cpu_supported', datetime.date(2019, 10, 1)),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -65,12 +65,9 @@ def _calculate_dynamic_qparams(X, dtype):
|
|||
|
||||
@contextmanager
|
||||
def enable_mobile_quantized_engine():
|
||||
previous = torch.backends.quantized.engine
|
||||
torch.backends.quantized.engine = 'qnnpack'
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
qengines = torch.backends.quantized.get_supported_qengines()
|
||||
if 'fbgemm' in qengines:
|
||||
torch.backends.quantized.engine = 'fbgemm'
|
||||
else:
|
||||
torch.backends.quantized.engine = 'none'
|
||||
torch.backends.quantized.engine = previous
|
||||
|
|
|
|||
|
|
@ -7210,9 +7210,9 @@ a")
|
|||
a = A()
|
||||
self.assertEqual(a.with_docstring.__doc__, 'test str')
|
||||
|
||||
@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
|
||||
'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
def test_rnn_cell_quantized(self):
|
||||
d_in, d_hid = 2, 2
|
||||
|
||||
|
|
@ -7304,9 +7304,9 @@ a")
|
|||
for out, ref_out in zip(outs, ref_outs):
|
||||
torch.testing.assert_allclose(out, ref_out)
|
||||
|
||||
@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
|
||||
'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
def test_rnn_quantized(self):
|
||||
d_in, d_hid = 2, 2
|
||||
|
||||
|
|
@ -12378,7 +12378,7 @@ a")
|
|||
|
||||
traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
|
||||
|
||||
if torch.fbgemm_is_cpu_supported():
|
||||
if 'fbgemm' in torch.backends.quantized.supported_engines:
|
||||
def test_quantization_modules(self):
|
||||
K1, N1 = 2, 2
|
||||
|
||||
|
|
@ -15189,7 +15189,7 @@ a")
|
|||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
@unittest.skipIf(not torch.fbgemm_is_cpu_supported(), "requires FBGEMM")
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, "requires FBGEMM")
|
||||
def test_erase_class_tensor_shapes(self):
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features):
|
||||
|
|
@ -16064,7 +16064,7 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||
def test_snli(self):
|
||||
self._test_snli(self, device='cpu')
|
||||
|
||||
if torch.fbgemm_is_cpu_supported():
|
||||
if 'fbgemm' in torch.backends.quantized.supported_engines:
|
||||
def test_snli_quantized(self):
|
||||
self._test_snli(self, device='cpu', quantized=True)
|
||||
|
||||
|
|
@ -16206,7 +16206,7 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||
def test_vae(self):
|
||||
self._test_vae(self, device='cpu')
|
||||
|
||||
if torch.fbgemm_is_cpu_supported():
|
||||
if 'fbgemm' in torch.backends.quantized.supported_engines:
|
||||
def test_vae_quantized(self):
|
||||
self._test_vae(self, device='cpu', quantized=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -2459,9 +2459,9 @@ class TestNN(NNTestCase):
|
|||
# should be bitwise equal
|
||||
self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0)
|
||||
|
||||
@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
|
||||
'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
def test_fb_fc_packed(self):
|
||||
X = np.random.rand(16, 16).astype(np.float32) - 0.5
|
||||
W = np.random.rand(16, 16).astype(np.float32) - 0.5
|
||||
|
|
|
|||
|
|
@ -30,11 +30,9 @@ from hypothesis_utils import no_deadline
|
|||
import io
|
||||
import copy
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class PostTrainingQuantTest(QuantizationTestCase):
|
||||
def test_single_layer(self):
|
||||
r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
|
||||
|
|
@ -292,11 +290,9 @@ class PostTrainingQuantTest(QuantizationTestCase):
|
|||
|
||||
checkQuantized(model)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class PostTrainingDynamicQuantTest(QuantizationTestCase):
|
||||
def test_single_layer(self):
|
||||
r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
|
||||
|
|
@ -569,11 +565,9 @@ class PostTrainingDynamicQuantTest(QuantizationTestCase):
|
|||
for out, ref in zip(final_hiddens_fp16, ref_hid):
|
||||
torch.testing.assert_allclose(out, ref)
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class QuantizationAwareTrainingTest(QuantizationTestCase):
|
||||
def test_manual(self):
|
||||
model = ManualLinearQATModel()
|
||||
|
|
@ -656,10 +650,9 @@ class ScriptabilityTest(QuantizationTestCase):
|
|||
self.checkScriptable(self.qmodel_under_test, [(xq, xq)], check_save_load=True)
|
||||
self.checkScriptable(self.model_under_test, [(xq.dequantize(), xq.dequantize())], check_save_load=True)
|
||||
|
||||
@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
|
||||
'Quantization requires FBGEMM. FBGEMM does not play'
|
||||
' well with UBSAN at the moment, so we skip the test if'
|
||||
' we are in a UBSAN environment.')
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class FusionTest(QuantizationTestCase):
|
||||
def test_fuse_module_train(self):
|
||||
model = ModelForFusion(default_qat_qconfig).train()
|
||||
|
|
@ -901,10 +894,9 @@ class ObserverTest(QuantizationTestCase):
|
|||
loaded = torch.jit.load(buf)
|
||||
self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())
|
||||
|
||||
@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
|
||||
'Quantization requires FBGEMM. FBGEMM does not play'
|
||||
' well with UBSAN at the moment, so we skip the test if'
|
||||
' we are in a UBSAN environment.')
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class RecordHistogramObserverTest(QuantizationTestCase):
|
||||
def test_record_observer(self):
|
||||
model = SingleLayerLinearModel()
|
||||
|
|
|
|||
|
|
@ -967,11 +967,9 @@ class TestQuantizedOps(TestCase):
|
|||
self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class TestDynamicQuantizedLinear(TestCase):
|
||||
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""
|
||||
@no_deadline
|
||||
|
|
@ -1086,11 +1084,9 @@ class TestDynamicQuantizedLinear(TestCase):
|
|||
self.assertEqual(Y_fp32, Y_fp32_ref,
|
||||
message="torch.ops.quantized.linear_dynamic (fbgemm) results are off")
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class TestQuantizedLinear(unittest.TestCase):
|
||||
"""Tests the correctness of the quantized linear and linear_relu op."""
|
||||
@no_deadline
|
||||
|
|
@ -1264,11 +1260,9 @@ class TestQuantizedLinear(unittest.TestCase):
|
|||
W_q.q_zero_point(), W_q_origin.q_zero_point())
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class TestQuantizedConv(unittest.TestCase):
|
||||
"""Tests the correctness of quantized convolution op."""
|
||||
@given(batch_size=st.integers(1, 3),
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
import torch
|
||||
import torch.jit
|
||||
import unittest
|
||||
from common_utils import run_tests
|
||||
from common_quantization import QuantizationTestCase, ModelMultipleOps
|
||||
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
"Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
class ModelNumerics(QuantizationTestCase):
|
||||
def test_float_quant_compare(self):
|
||||
torch.manual_seed(42)
|
||||
|
|
|
|||
|
|
@ -34,11 +34,9 @@ class FunctionalAPITest(QuantizationTestCase):
|
|||
self.assertEqual(qY, qY_hat)
|
||||
|
||||
@no_deadline
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
@given(
|
||||
use_bias=st.booleans(),
|
||||
)
|
||||
|
|
@ -89,11 +87,9 @@ class FunctionalAPITest(QuantizationTestCase):
|
|||
|
||||
class DynamicModuleAPITest(QuantizationTestCase):
|
||||
@no_deadline
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
@given(
|
||||
batch_size=st.integers(1, 5),
|
||||
in_features=st.integers(16, 32),
|
||||
|
|
@ -209,11 +205,9 @@ class ModuleAPITest(QuantizationTestCase):
|
|||
|
||||
|
||||
@no_deadline
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
@given(
|
||||
batch_size=st.integers(1, 5),
|
||||
in_features=st.integers(16, 32),
|
||||
|
|
@ -341,11 +335,9 @@ class ModuleAPITest(QuantizationTestCase):
|
|||
self.assertEqual(rqr, rqr2)
|
||||
|
||||
@no_deadline
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
@given(
|
||||
use_bias=st.booleans(),
|
||||
use_fused=st.booleans(),
|
||||
|
|
|
|||
|
|
@ -36,11 +36,9 @@ class WeightObserver(Observer):
|
|||
super(WeightObserver, self).__init__()
|
||||
self.dtype = torch.qint8
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.fbgemm_is_cpu_supported(),
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.",
|
||||
)
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
||||
" with instruction set support avx2 or newer.")
|
||||
@unittest.skip("temoprarily disable the test")
|
||||
class QuantizerTestCase(TestCase):
|
||||
@_tmp_donotuse_dont_inline_everything
|
||||
|
|
|
|||
|
|
@ -2070,8 +2070,8 @@ class _TestTorchMixin(object):
|
|||
test_inference(torch.float64)
|
||||
test_inference(torch.float32)
|
||||
|
||||
def test_qengnie(self):
|
||||
qengines = torch.backends.quantized.get_supported_qengines()
|
||||
def test_qengine(self):
|
||||
qengines = torch.backends.quantized.supported_engines
|
||||
original_qe = torch.backends.quantized.engine
|
||||
for qe in qengines:
|
||||
torch.backends.quantized.engine = qe
|
||||
|
|
@ -5356,19 +5356,19 @@ class _TestTorchMixin(object):
|
|||
self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
|
||||
|
||||
def test_load_unicode_error_msg(self):
|
||||
# This Pickle contains a Python 2 module with Unicode data and the
|
||||
# This Pickle contains a Python 2 module with Unicode data and the
|
||||
# loading should fail if the user explicitly specifies ascii encoding!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
if sys.version_info >= (3, 0):
|
||||
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
|
||||
else:
|
||||
# Just checks the module loaded
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
|
||||
def test_load_python2_unicode_module(self):
|
||||
# This Pickle contains some Unicode data!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
|
||||
def test_load_error_msg(self):
|
||||
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ import torch
|
|||
import types
|
||||
|
||||
# This function should correspond to the enums present in c10/core/QEngine.h
|
||||
def get_qengine_id(qengine):
|
||||
def _get_qengine_id(qengine):
|
||||
# type: (str) -> int
|
||||
if qengine == 'none':
|
||||
if qengine == 'none' or qengine == '' or qengine is None:
|
||||
ret = 0
|
||||
elif qengine == 'fbgemm':
|
||||
ret = 1
|
||||
|
|
@ -18,25 +18,25 @@ def get_qengine_id(qengine):
|
|||
return ret
|
||||
|
||||
# This function should correspond to the enums present in c10/core/QEngine.h
|
||||
def get_qengine_str(qengine):
|
||||
def _get_qengine_str(qengine):
|
||||
# type: (int) -> str
|
||||
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack'}
|
||||
return all_engines.get(qengine)
|
||||
|
||||
def get_supported_qengines():
|
||||
qengines = torch._C._supported_qengines()
|
||||
return [get_qengine_str(qe) for qe in qengines]
|
||||
|
||||
class ContextProp(object):
|
||||
def __init__(self, getter, setter):
|
||||
self.getter = getter
|
||||
self.setter = setter
|
||||
|
||||
class _QEngineProp(object):
|
||||
def __get__(self, obj, objtype):
|
||||
return get_qengine_str(self.getter())
|
||||
return _get_qengine_str(torch._C._get_qengine())
|
||||
|
||||
def __set__(self, obj, val):
|
||||
self.setter(get_qengine_id(val))
|
||||
torch._C._set_qengine(_get_qengine_id(val))
|
||||
|
||||
class _SupportedQEnginesProp(object):
|
||||
def __get__(self, obj, objtype):
|
||||
qengines = torch._C._supported_qengines()
|
||||
return [_get_qengine_str(qe) for qe in qengines]
|
||||
|
||||
def __set__(self, obj, val):
|
||||
raise RuntimeError("Assignment not supported")
|
||||
|
||||
class QuantizedEngine(types.ModuleType):
|
||||
def __init__(self, m, name):
|
||||
|
|
@ -45,7 +45,9 @@ class QuantizedEngine(types.ModuleType):
|
|||
|
||||
def __getattr__(self, attr):
|
||||
return self.m.__getattribute__(attr)
|
||||
engine = ContextProp(torch._C._get_qengine, torch._C._set_qengine)
|
||||
|
||||
engine = _QEngineProp()
|
||||
supported_engines = _SupportedQEnginesProp()
|
||||
|
||||
# This is the sys.modules replacement trick, see
|
||||
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
|
||||
|
|
|
|||
Loading…
Reference in a new issue