[Pytorch][Ondevice quantization] Add device side API to convert model (#83807)

Summary:
This diff adds a device-side API which will convert the model to its
quantized equivalent. The input model must have been prepared AOT
(ahead of time) for quantization.

API is implemented by:
- Running the reset observers method
- Running observe method
- Running quantize method
- And replacing method, e.g. forward, with its quantized equivalent.

Test Plan:
test/quantization/jit/test_ondevice_quantization.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D38889818](https://our.internmc.facebook.com/intern/diff/D38889818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83807
Approved by: https://github.com/iseeyuan
This commit is contained in:
Kimish Patel 2022-08-27 16:06:16 -07:00 committed by PyTorch MergeBot
parent eebdcb5a2e
commit cfd18e105f
10 changed files with 255 additions and 42 deletions

View file

@ -1417,6 +1417,7 @@ def define_buck_targets(
"torch/csrc/autograd/VariableTypeManual.cpp",
"torch/csrc/autograd/FunctionsManual.cpp",
"torch/csrc/api/src/data/datasets/mnist.cpp",
"torch/csrc/jit/mobile/quantization.cpp",
"torch/csrc/jit/mobile/train/export_data.cpp",
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
"torch/csrc/jit/mobile/train/random.cpp",

View file

@ -564,6 +564,7 @@ torch_mobile_core = [
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/parse_operators.cpp",
"torch/csrc/jit/mobile/quantization.cpp",
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
"torch/csrc/jit/runtime/register_prim_ops.cpp",
"torch/csrc/jit/runtime/register_special_ops.cpp",
@ -612,6 +613,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/parse_operators.cpp",
"torch/csrc/jit/mobile/quantization.cpp",
"torch/csrc/jit/mobile/train/export_data.cpp",
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
"torch/csrc/jit/mobile/train/random.cpp",

View file

@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/quantization.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp

View file

@ -2,6 +2,7 @@
# Owner(s): ["oncall: quantization"]
import torch
import torch._C_flatbuffer
from torch.ao.quantization import (
default_dynamic_qconfig,
@ -22,11 +23,13 @@ from torch.testing._internal.common_quantization import (
LinearAddModel,
)
from torch.jit.mobile import _load_for_lite_interpreter
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
from torch.testing import FileCheck
from torch.utils import bundled_inputs as bundled_inputs
import io
from typing import Dict
class myMod(torch.nn.Module):
def __init__(self, weight):
@ -396,7 +399,7 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
self.assertTrue(thrown)
def _check_serialization_deserialization(self, model):
def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
model.eval()
inputs = model.get_example_inputs()
ref_m = torch.jit.script(model)
@ -410,27 +413,40 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
ref_m = torch.jit.load(buffer)
ref_output = ref_m(*inputs)
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
if not check_device_side_api:
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
else:
# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
first_input, = inputs
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer) # Error here
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
self.assertFalse(m.find_method("quantized_forward"))
self.assertFalse(m.find_method("quantize_forward"))
self.assertFalse(m.find_method("observe_forward"))
self.assertFalse(m.find_method("reset_observers_forward"))
output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer) # Error here
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
self.assertTrue(torch.allclose(ref_output, output))
# Now serialize to flatbuffer, load from flatbuffer, and check
dict: Dict[str, str] = {}
bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, dict)
m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(bytes))
fb_output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, fb_output))
model.eval()
inputs = model.get_example_inputs()
@ -445,27 +461,41 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
ref_m = torch.jit.load(buffer)
ref_output = ref_m(*inputs)
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
if not check_device_side_api:
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
else:
# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
first_input, = inputs
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer) # Error here
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
self.assertFalse(m.find_method("quantized_forward"))
self.assertFalse(m.find_method("quantize_forward"))
self.assertFalse(m.find_method("observe_forward"))
self.assertFalse(m.find_method("reset_observers_forward"))
output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer) # Error here
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
self.assertTrue(torch.allclose(ref_output, output))
def _check_serialization_deserialization(self, model):
    # Full-JIT path: serialize/deserialize the PTQ-prepared model and drive
    # the reset_observers_/observe_/quantize_/quantized_forward methods
    # explicitly (check_device_side_api=False in the shared helper).
    self._check_serdes_and_device_side_api_helper(model, False)
def _check_device_side_api(self, model):
    # Lite-interpreter path: exercises the device-side
    # torch._C._quantize_ondevice_ptq_dynamic API against bundled inputs
    # (check_device_side_api=True in the shared helper).
    self._check_serdes_and_device_side_api_helper(model, True)
def test_quantize_forward(self):
@ -492,3 +522,8 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
def test_serialization_deserialization(self):
model = MyConvLinearModule()
self._check_serialization_deserialization(model)
def test_device_side_api(self):
    # End-to-end test of on-device dynamic PTQ finalization through the
    # lite interpreter (torch._C._quantize_ondevice_ptq_dynamic).
    model = MyConvLinearModule()
    self._check_device_side_api(model)

View file

@ -298,6 +298,7 @@ def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...

View file

@ -43,6 +43,50 @@ Method Module::get_method(const std::string& name) const {
AT_ERROR("Method '", name, "' is not defined.");
}
bool Module::compareMethodSchemas(
const std::string& name_1,
const std::string& name_2) {
c10::optional<c10::FunctionSchema> schema_1, schema_2;
for (const auto& fn : cu_->methods()) {
if (fn->name() == name_1) {
schema_1 = fn->getSchema();
}
if (fn->name() == name_2) {
schema_2 = fn->getSchema();
}
}
if (schema_1.has_value() && schema_2.has_value()) {
return (schema_1 == schema_2);
}
return false;
}
void Module::unsafeRemoveMethod(const std::string& basename) {
int64_t i = 0;
for (; i < cu_->methods().size(); ++i) {
if ((cu_->methods()[i])->name() == basename) {
break;
}
}
object_->type()->unsafeRemoveMethod(basename);
cu_->unsafeRemoveFunction(i);
}
void Module::unsafeCopyMethod(
const std::string& new_method_name,
const Function& to_be_copied) {
TORCH_CHECK(
!find_method(new_method_name).has_value(),
"Trying to replace existing method.");
const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
c10::QualifiedName qualified_method_name(
tobe_copied_name.prefix(), new_method_name);
std::unique_ptr<Function> new_fn = std::make_unique<Function>(
qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
object_->type()->addMethod(new_fn.get());
cu_->register_function(std::move(new_fn));
}
c10::optional<Method> Module::find_method(const std::string& basename) const {
for (const auto& fn : cu_->methods()) {
if (fn->name() == basename) {

View file

@ -3,6 +3,7 @@
#include <torch/csrc/jit/mobile/debug_info.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/quantization.h>
namespace torch {
namespace jit {
@ -42,6 +43,10 @@ class CompilationUnit {
Function* find_function(const c10::QualifiedName& qn);
const Function* find_function(const c10::QualifiedName& qn) const;
// Erases the function at position `index` in methods_. "unsafe": performs no
// bounds checking, and invalidates any Function pointers previously handed
// out for the erased entry. Callers must pass a valid index (see
// Module::unsafeRemoveMethod).
void unsafeRemoveFunction(const int64_t index) {
  methods_.erase(methods_.begin() + index);
}
private:
std::vector<std::unique_ptr<Function>> methods_;
};
@ -71,6 +76,7 @@ class TORCH_API Module {
return get_method("forward")(std::move(inputs));
}
c10::optional<Method> find_method(const std::string& basename) const;
const std::string name() const {
return object_->name();
}
@ -152,6 +158,18 @@ class TORCH_API Module {
}
private:
friend class quantization::PTQQuanizationHelper;
bool compareMethodSchemas(
const std::string& name_1,
const std::string& name_2);
void unsafeRemoveMethod(const std::string& basename);
void unsafeCopyMethod(
const std::string& new_method_name,
const Function& to_be_copied);
c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;

View file

@ -0,0 +1,66 @@
#include <ATen/Context.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>
namespace torch {
namespace jit {
namespace mobile {
namespace quantization {
// Finalizes on-device dynamic PTQ for `method_name` on module `m`:
//  1. resets observers, 2. runs the observe method on the module's bundled
//  inputs, 3. runs the quantize method, 4. replaces <method_name> with
//  quantized_<method_name> and removes all quantization helper methods.
// `m` must have been prepared AOT: reset_observers_/observe_/quantize_/
// quantized_<method_name> and get_all_bundled_inputs must all exist.
void PTQQuanizationHelper::quantize_dynamic(
    torch::jit::mobile::Module& m,
    const std::string& method_name) {
  // Keep original weights around after prepacking; the quantize step still
  // needs to read them.
  at::globalContext().setReleaseWeightsWhenPrepacking(false);
  std::string reset_observers_method_name = "reset_observers_" + method_name;
  std::string observe_method_name = "observe_" + method_name;
  std::string quantize_method_name = "quantize_" + method_name;
  std::string quantized_method_name = "quantized_" + method_name;

  TORCH_CHECK(
      m.find_method(reset_observers_method_name).has_value(),
      "PTQ ready module must have ",
      reset_observers_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(observe_method_name),
      "PTQ ready module must have ",
      // Fixed: previously printed reset_observers_method_name here
      // (copy-paste error), misreporting which method was missing.
      observe_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantize_method_name),
      "PTQ ready module must have ",
      quantize_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantized_method_name),
      "PTQ ready module must have ",
      quantized_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method("get_all_bundled_inputs"),
      "PTQ ready module must have get_all_bundled_inputs method.");

  // Calibrate using the first set of bundled inputs.
  auto inputs = m.run_method("get_all_bundled_inputs")
                    .toList()
                    .get(0)
                    .toTupleRef()
                    .elements()
                    .vec();
  m.get_method(reset_observers_method_name)({});
  m.get_method(observe_method_name)(inputs);
  m.get_method(quantize_method_name)(inputs);

  // NOTE(review): the result of this comparison is discarded, so a schema
  // mismatch between <method_name> and quantized_<method_name> passes
  // silently. FunctionSchema equality appears to include the method name,
  // which would make a hard TORCH_CHECK here always fail — confirm the
  // intended comparison semantics before enforcing it.
  m.compareMethodSchemas(method_name, quantized_method_name);

  // Rewire: <method_name> becomes the quantized implementation, and all
  // quantization-only helper methods are stripped from the module.
  m.unsafeRemoveMethod(method_name);
  const Function& to_be_copied =
      m.find_method(quantized_method_name).value().function();
  m.unsafeCopyMethod(method_name, to_be_copied);
  m.unsafeRemoveMethod(quantized_method_name);
  m.unsafeRemoveMethod(quantize_method_name);
  m.unsafeRemoveMethod(observe_method_name);
  m.unsafeRemoveMethod(reset_observers_method_name);
}
} // namespace quantization
} // namespace mobile
} // namespace jit
} // namespace torch

View file

@ -0,0 +1,38 @@
#pragma once
#include <c10/macros/Export.h>
#include <string>
namespace torch {
namespace jit {
namespace mobile {
class Module;
namespace quantization {
/*
 * Device-side PTQ API.
 *
 * Once the model has been prepared for quantization AOT (on the server
 * side), it is sent to the device, where it may be further trained or
 * calibrated. At the end of training, before the model is readied for
 * inference, it must be quantized.
 *
 * Usage:
 *   PTQQuanizationHelper ptq_helper;
 *   ptq_helper.quantize_dynamic(m, "forward");
 *
 * Args:
 *   m: An instance of mobile::Module, captured by reference. It is mutated
 *      in place: its <method_name> method is replaced with the quantized
 *      equivalent.
 *   method_name: Name of the method to be quantized. AOT preparation for
 *      quantization must have been done for this method.
 *
 * Result:
 *   `m` is mutated in place; its size should be smaller due to weight
 *   quantization, and its <method_name> method should use quantized ops.
 */
class TORCH_API PTQQuanizationHelper {
 public:
  PTQQuanizationHelper() = default;

  // Replaces m.<method_name> with its dynamically-quantized equivalent.
  // Requires reset_observers_/observe_/quantize_/quantized_<method_name>
  // and get_all_bundled_inputs methods to exist on `m`.
  void quantize_dynamic(
      torch::jit::mobile::Module& m,
      const std::string& method_name);
};
} // namespace quantization
} // namespace mobile
} // namespace jit
} // namespace torch

View file

@ -16,6 +16,7 @@
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
@ -1953,6 +1954,12 @@ void initJitScriptBindings(PyObject* module) {
m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
});
m.def(
"_quantize_ondevice_ptq_dynamic",
[](mobile::Module& m, const std::string& method_name) {
mobile::quantization::PTQQuanizationHelper ptq_helper;
ptq_helper.quantize_dynamic(m, method_name);
});
m.def("_jit_set_emit_hooks", setEmitHooks);
m.def("_jit_get_emit_hooks", getEmitHooks);