diff --git a/buckbuild.bzl b/buckbuild.bzl index ae1519ea8f5..76e0db976be 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -1417,6 +1417,7 @@ def define_buck_targets( "torch/csrc/autograd/VariableTypeManual.cpp", "torch/csrc/autograd/FunctionsManual.cpp", "torch/csrc/api/src/data/datasets/mnist.cpp", + "torch/csrc/jit/mobile/quantization.cpp", "torch/csrc/jit/mobile/train/export_data.cpp", "torch/csrc/jit/mobile/train/optim/sgd.cpp", "torch/csrc/jit/mobile/train/random.cpp", diff --git a/build_variables.bzl b/build_variables.bzl index eb09a2a5f59..ec08b918c5f 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -564,6 +564,7 @@ torch_mobile_core = [ "torch/csrc/jit/mobile/observer.cpp", "torch/csrc/jit/mobile/parse_bytecode.cpp", "torch/csrc/jit/mobile/parse_operators.cpp", + "torch/csrc/jit/mobile/quantization.cpp", "torch/csrc/jit/mobile/upgrader_mobile.cpp", "torch/csrc/jit/runtime/register_prim_ops.cpp", "torch/csrc/jit/runtime/register_special_ops.cpp", @@ -612,6 +613,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ "torch/csrc/jit/mobile/observer.cpp", "torch/csrc/jit/mobile/parse_bytecode.cpp", "torch/csrc/jit/mobile/parse_operators.cpp", + "torch/csrc/jit/mobile/quantization.cpp", "torch/csrc/jit/mobile/train/export_data.cpp", "torch/csrc/jit/mobile/train/optim/sgd.cpp", "torch/csrc/jit/mobile/train/random.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index a9048980401..584d550b2e8 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp + ${TORCH_SRC_DIR}/csrc/jit/mobile/quantization.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp diff --git a/test/quantization/jit/test_ondevice_quantization.py b/test/quantization/jit/test_ondevice_quantization.py index 8b453a50f1a..fa3cfaab24b 100644 --- a/test/quantization/jit/test_ondevice_quantization.py +++ b/test/quantization/jit/test_ondevice_quantization.py @@ -2,6 +2,7 @@ # Owner(s): ["oncall: quantization"] import torch +import torch._C_flatbuffer from torch.ao.quantization import ( default_dynamic_qconfig, @@ -22,11 +23,13 @@ from torch.testing._internal.common_quantization import ( LinearAddModel, ) -from torch.jit.mobile import _load_for_lite_interpreter +from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule from torch.testing import FileCheck +from torch.utils import bundled_inputs as bundled_inputs import io +from typing import Dict class myMod(torch.nn.Module): def __init__(self, weight): @@ -396,7 +399,7 @@ class TestOnDeviceDynamicPTQFinalize(TestCase): self.assertTrue(thrown) - def _check_serialization_deserialization(self, model): + def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False): model.eval() inputs = model.get_example_inputs() ref_m = torch.jit.script(model) @@ -410,27 +413,40 @@ class TestOnDeviceDynamicPTQFinalize(TestCase): ref_m = torch.jit.load(buffer) ref_output = ref_m(*inputs) - m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) - buffer = io.BytesIO() - torch.jit.save(m, buffer) - buffer.seek(0) - m = torch.jit.load(buffer) - m.reset_observers_forward() - m.observe_forward(*inputs) - m.quantize_forward(*inputs) - output = m.quantized_forward(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) + if not check_device_side_api: + m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + m = torch.jit.load(buffer) + m.reset_observers_forward() + m.observe_forward(*inputs) + m.quantize_forward(*inputs) + output = m.quantized_forward(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) + else: + # check for lite interpreter + m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) + first_input, = inputs + rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype) + m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )]) + buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) + buffer.seek(0) + m = _load_for_lite_interpreter(buffer) # Error here + torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward") + self.assertFalse(m.find_method("quantized_forward")) + self.assertFalse(m.find_method("quantize_forward")) + self.assertFalse(m.find_method("observe_forward")) + self.assertFalse(m.find_method("reset_observers_forward")) + output = m(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) - # check for lite interpreter - m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) - buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) - buffer.seek(0) - m = _load_for_lite_interpreter(buffer) # Error here - m.run_method("reset_observers_forward") - m.run_method("observe_forward", *inputs) - m.run_method("quantize_forward", *inputs) - output = m.run_method("quantized_forward", *inputs) - self.assertTrue(torch.allclose(ref_output, output)) + # Now serialize to flabuffer and load from fb and check + dict: Dict[str, str] = {} + bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, dict) + m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(bytes)) + fb_output = m(*inputs) + self.assertTrue(torch.allclose(ref_output, fb_output)) model.eval() inputs = model.get_example_inputs() @@ -445,27 +461,41 @@ class TestOnDeviceDynamicPTQFinalize(TestCase): ref_m = torch.jit.load(buffer) ref_output = ref_m(*inputs) - m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) - buffer = io.BytesIO() - torch.jit.save(m, buffer) - buffer.seek(0) - m = torch.jit.load(buffer) - m.reset_observers_forward() - m.observe_forward(*inputs) - m.quantize_forward(*inputs) - output = m.quantized_forward(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) + if not check_device_side_api: + m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + m = torch.jit.load(buffer) + m.reset_observers_forward() + m.observe_forward(*inputs) + m.quantize_forward(*inputs) + output = m.quantized_forward(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) + else: + # check for lite interpreter + m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) + first_input, = inputs + rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype) + m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )]) + buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) + buffer.seek(0) + m = _load_for_lite_interpreter(buffer) # Error here + torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward") + self.assertFalse(m.find_method("quantized_forward")) + self.assertFalse(m.find_method("quantize_forward")) + self.assertFalse(m.find_method("observe_forward")) + self.assertFalse(m.find_method("reset_observers_forward")) + output = m(*inputs) + self.assertTrue(torch.allclose(ref_output, output)) - # check for lite interpreter - m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) - buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) - buffer.seek(0) - m = _load_for_lite_interpreter(buffer) # Error here - m.run_method("reset_observers_forward") - m.run_method("observe_forward", *inputs) - m.run_method("quantize_forward", *inputs) - output = m.run_method("quantized_forward", *inputs) - self.assertTrue(torch.allclose(ref_output, output)) + + def _check_serialization_deserialization(self, model): + self._check_serdes_and_device_side_api_helper(model, False) + + + def _check_device_side_api(self, model): + self._check_serdes_and_device_side_api_helper(model, True) def test_quantize_forward(self): @@ -492,3 +522,8 @@ class TestOnDeviceDynamicPTQFinalize(TestCase): def test_serialization_deserialization(self): model = MyConvLinearModule() self._check_serialization_deserialization(model) + + + def test_device_side_api(self): + model = MyConvLinearModule() + self._check_device_side_api(model) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 92a3631f2d9..024c5ae394f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -298,6 +298,7 @@ def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ... def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ... def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ... def _export_operator_list(module: LiteScriptModule): ... +def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ... def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ... def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ... def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ... diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 2ef7c34c28b..5da8cb4a55d 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -43,6 +43,50 @@ Method Module::get_method(const std::string& name) const { AT_ERROR("Method '", name, "' is not defined."); } +bool Module::compareMethodSchemas( + const std::string& name_1, + const std::string& name_2) { + c10::optional schema_1, schema_2; + for (const auto& fn : cu_->methods()) { + if (fn->name() == name_1) { + schema_1 = fn->getSchema(); + } + if (fn->name() == name_2) { + schema_2 = fn->getSchema(); + } + } + if (schema_1.has_value() && schema_2.has_value()) { + return (schema_1 == schema_2); + } + return false; +} + +void Module::unsafeRemoveMethod(const std::string& basename) { + int64_t i = 0; + for (; i < cu_->methods().size(); ++i) { + if ((cu_->methods()[i])->name() == basename) { + break; + } + } + object_->type()->unsafeRemoveMethod(basename); + cu_->unsafeRemoveFunction(i); +} + +void Module::unsafeCopyMethod( + const std::string& new_method_name, + const Function& to_be_copied) { + TORCH_CHECK( + !find_method(new_method_name).has_value(), + "Trying to replace existing method."); + const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname(); + c10::QualifiedName qualified_method_name( + tobe_copied_name.prefix(), new_method_name); + std::unique_ptr new_fn = std::make_unique( + qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema()); + object_->type()->addMethod(new_fn.get()); + cu_->register_function(std::move(new_fn)); +} + c10::optional Module::find_method(const std::string& basename) const { for (const auto& fn : cu_->methods()) { if (fn->name() == basename) { diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index 01c76e14658..2b07831e673 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -42,6 +43,10 @@ class CompilationUnit { Function* find_function(const c10::QualifiedName& qn); const Function* find_function(const c10::QualifiedName& qn) const; + void unsafeRemoveFunction(const int64_t index) { + methods_.erase(methods_.begin() + index); + } + private: std::vector> methods_; }; @@ -71,6 +76,7 @@ class TORCH_API Module { return get_method("forward")(std::move(inputs)); } c10::optional find_method(const std::string& basename) const; + const std::string name() const { return object_->name(); } @@ -152,6 +158,18 @@ class TORCH_API Module { } private: + friend class quantization::PTQQuanizationHelper; + + bool compareMethodSchemas( + const std::string& name_1, + const std::string& name_2); + + void unsafeRemoveMethod(const std::string& basename); + + void unsafeCopyMethod( + const std::string& new_method_name, + const Function& to_be_copied); + c10::intrusive_ptr object_; std::unordered_map metadata_; std::shared_ptr cu_; diff --git a/torch/csrc/jit/mobile/quantization.cpp b/torch/csrc/jit/mobile/quantization.cpp new file mode 100644 index 00000000000..b391cf5ac0e --- /dev/null +++ b/torch/csrc/jit/mobile/quantization.cpp @@ -0,0 +1,66 @@ +#include +#include +#include + +namespace torch { +namespace jit { +namespace mobile { +namespace quantization { + +void PTQQuanizationHelper::quantize_dynamic( + torch::jit::mobile::Module& m, + const std::string& method_name) { + at::globalContext().setReleaseWeightsWhenPrepacking(false); + std::string reset_observers_method_name = "reset_observers_" + method_name; + std::string observe_method_name = "observe_" + method_name; + std::string quantize_method_name = "quantize_" + method_name; + std::string quantized_method_name = "quantized_" + method_name; + + TORCH_CHECK( + m.find_method(reset_observers_method_name).has_value(), + "PTQ ready module must have", + reset_observers_method_name, + " method."); + TORCH_CHECK( + m.find_method(observe_method_name), + "PTQ ready module must have", + reset_observers_method_name, + " method."); + TORCH_CHECK( + m.find_method(quantize_method_name), + "PTQ ready module must have", + quantize_method_name, + " method."); + TORCH_CHECK( + m.find_method(quantized_method_name), + "PTQ ready module must have", + quantized_method_name, + " method."); + TORCH_CHECK( + m.find_method("get_all_bundled_inputs"), + "PTQ ready module must have get_all_bundled_inputs method."); + + auto inputs = m.run_method("get_all_bundled_inputs") + .toList() + .get(0) + .toTupleRef() + .elements() + .vec(); + m.get_method(reset_observers_method_name)({}); + m.get_method(observe_method_name)(inputs); + m.get_method(quantize_method_name)(inputs); + + m.compareMethodSchemas(method_name, quantized_method_name); + m.unsafeRemoveMethod(method_name); + const Function& to_be_copied = + m.find_method(quantized_method_name).value().function(); + m.unsafeCopyMethod(method_name, to_be_copied); + m.unsafeRemoveMethod(quantized_method_name); + m.unsafeRemoveMethod(quantize_method_name); + m.unsafeRemoveMethod(observe_method_name); + m.unsafeRemoveMethod(reset_observers_method_name); +} +} // namespace quantization +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/quantization.h b/torch/csrc/jit/mobile/quantization.h new file mode 100644 index 00000000000..aa47dcb64b6 --- /dev/null +++ b/torch/csrc/jit/mobile/quantization.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +namespace mobile { +class Module; +namespace quantization { +/* + * Device side PTQ API. + * Once the model has been prepared for quantization on server side, such model + * is sent to device. On device side the model is further trained. At the end of + * the training, before the model is readied for inference, we need to quantize + * the model. + * Usage of this API is as follows. + * PTQQuanizationHelper ptq_helper; + * ptq_helper.quantize_dynamic(m, "forward"); + * Args: + * m: Captured by reference, an instance of mobile::Module. This module will be + * mutated in place to replace its method with quantized + * equivalent. method:name: Name of the method to be quantized. AOT preparation + * for quantization must also have been done for this method. Returns: In place + * mutated `m` whose size should be smaller due to weight quantization and whose + * method should use quantized ops + */ +class TORCH_API PTQQuanizationHelper { + public: + PTQQuanizationHelper() = default; + void quantize_dynamic( + torch::jit::mobile::Module& m, + const std::string& method_name); +}; +} // namespace quantization +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 3f825d9bd52..110c2f4a70c 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -1953,6 +1954,12 @@ void initJitScriptBindings(PyObject* module) { m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) { return debugMakeSet(torch::jit::mobile::_export_operator_list(sm)); }); + m.def( + "_quantize_ondevice_ptq_dynamic", + [](mobile::Module& m, const std::string& method_name) { + mobile::quantization::PTQQuanizationHelper ptq_helper; + ptq_helper.quantize_dynamic(m, method_name); + }); m.def("_jit_set_emit_hooks", setEmitHooks); m.def("_jit_get_emit_hooks", getEmitHooks);