mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Pytorch][Ondevice quantization] Add device side API to convert model (#83807)
Summary: This diff adds a device-side API that converts a model to its quantized equivalent. The input model must have been prepared AOT for quantization. The API is implemented by: - Running reset observers - Running the observe method - Running the quantize method - And replacing the method, e.g. forward, with its quantized equivalent. Test Plan: test/quantization/jit/test_ondevice_quantization.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D38889818](https://our.internmc.facebook.com/intern/diff/D38889818) Pull Request resolved: https://github.com/pytorch/pytorch/pull/83807 Approved by: https://github.com/iseeyuan
This commit is contained in:
parent
eebdcb5a2e
commit
cfd18e105f
10 changed files with 255 additions and 42 deletions
|
|
@ -1417,6 +1417,7 @@ def define_buck_targets(
|
|||
"torch/csrc/autograd/VariableTypeManual.cpp",
|
||||
"torch/csrc/autograd/FunctionsManual.cpp",
|
||||
"torch/csrc/api/src/data/datasets/mnist.cpp",
|
||||
"torch/csrc/jit/mobile/quantization.cpp",
|
||||
"torch/csrc/jit/mobile/train/export_data.cpp",
|
||||
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
|
||||
"torch/csrc/jit/mobile/train/random.cpp",
|
||||
|
|
|
|||
|
|
@ -564,6 +564,7 @@ torch_mobile_core = [
|
|||
"torch/csrc/jit/mobile/observer.cpp",
|
||||
"torch/csrc/jit/mobile/parse_bytecode.cpp",
|
||||
"torch/csrc/jit/mobile/parse_operators.cpp",
|
||||
"torch/csrc/jit/mobile/quantization.cpp",
|
||||
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
|
||||
"torch/csrc/jit/runtime/register_prim_ops.cpp",
|
||||
"torch/csrc/jit/runtime/register_special_ops.cpp",
|
||||
|
|
@ -612,6 +613,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
|
|||
"torch/csrc/jit/mobile/observer.cpp",
|
||||
"torch/csrc/jit/mobile/parse_bytecode.cpp",
|
||||
"torch/csrc/jit/mobile/parse_operators.cpp",
|
||||
"torch/csrc/jit/mobile/quantization.cpp",
|
||||
"torch/csrc/jit/mobile/train/export_data.cpp",
|
||||
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
|
||||
"torch/csrc/jit/mobile/train/random.cpp",
|
||||
|
|
|
|||
|
|
@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/quantization.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import torch
|
||||
import torch._C_flatbuffer
|
||||
|
||||
from torch.ao.quantization import (
|
||||
default_dynamic_qconfig,
|
||||
|
|
@ -22,11 +23,13 @@ from torch.testing._internal.common_quantization import (
|
|||
LinearAddModel,
|
||||
)
|
||||
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.utils import bundled_inputs as bundled_inputs
|
||||
|
||||
import io
|
||||
from typing import Dict
|
||||
|
||||
class myMod(torch.nn.Module):
|
||||
def __init__(self, weight):
|
||||
|
|
@ -396,7 +399,7 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
|
|||
self.assertTrue(thrown)
|
||||
|
||||
|
||||
def _check_serialization_deserialization(self, model):
|
||||
def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
|
||||
model.eval()
|
||||
inputs = model.get_example_inputs()
|
||||
ref_m = torch.jit.script(model)
|
||||
|
|
@ -410,27 +413,40 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
|
|||
ref_m = torch.jit.load(buffer)
|
||||
ref_output = ref_m(*inputs)
|
||||
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(m, buffer)
|
||||
buffer.seek(0)
|
||||
m = torch.jit.load(buffer)
|
||||
m.reset_observers_forward()
|
||||
m.observe_forward(*inputs)
|
||||
m.quantize_forward(*inputs)
|
||||
output = m.quantized_forward(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
if not check_device_side_api:
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(m, buffer)
|
||||
buffer.seek(0)
|
||||
m = torch.jit.load(buffer)
|
||||
m.reset_observers_forward()
|
||||
m.observe_forward(*inputs)
|
||||
m.quantize_forward(*inputs)
|
||||
output = m.quantized_forward(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
else:
|
||||
# check for lite interpreter
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
first_input, = inputs
|
||||
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
|
||||
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
|
||||
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
m = _load_for_lite_interpreter(buffer) # Error here
|
||||
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
|
||||
self.assertFalse(m.find_method("quantized_forward"))
|
||||
self.assertFalse(m.find_method("quantize_forward"))
|
||||
self.assertFalse(m.find_method("observe_forward"))
|
||||
self.assertFalse(m.find_method("reset_observers_forward"))
|
||||
output = m(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
|
||||
# check for lite interpreter
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
m = _load_for_lite_interpreter(buffer) # Error here
|
||||
m.run_method("reset_observers_forward")
|
||||
m.run_method("observe_forward", *inputs)
|
||||
m.run_method("quantize_forward", *inputs)
|
||||
output = m.run_method("quantized_forward", *inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
# Now serialize to flatbuffer, load it back from flatbuffer, and check
|
||||
dict: Dict[str, str] = {}
|
||||
bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, dict)
|
||||
m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(bytes))
|
||||
fb_output = m(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, fb_output))
|
||||
|
||||
model.eval()
|
||||
inputs = model.get_example_inputs()
|
||||
|
|
@ -445,27 +461,41 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
|
|||
ref_m = torch.jit.load(buffer)
|
||||
ref_output = ref_m(*inputs)
|
||||
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(m, buffer)
|
||||
buffer.seek(0)
|
||||
m = torch.jit.load(buffer)
|
||||
m.reset_observers_forward()
|
||||
m.observe_forward(*inputs)
|
||||
m.quantize_forward(*inputs)
|
||||
output = m.quantized_forward(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
if not check_device_side_api:
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(m, buffer)
|
||||
buffer.seek(0)
|
||||
m = torch.jit.load(buffer)
|
||||
m.reset_observers_forward()
|
||||
m.observe_forward(*inputs)
|
||||
m.quantize_forward(*inputs)
|
||||
output = m.quantized_forward(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
else:
|
||||
# check for lite interpreter
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
first_input, = inputs
|
||||
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
|
||||
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
|
||||
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
m = _load_for_lite_interpreter(buffer) # Error here
|
||||
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
|
||||
self.assertFalse(m.find_method("quantized_forward"))
|
||||
self.assertFalse(m.find_method("quantize_forward"))
|
||||
self.assertFalse(m.find_method("observe_forward"))
|
||||
self.assertFalse(m.find_method("reset_observers_forward"))
|
||||
output = m(*inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
|
||||
# check for lite interpreter
|
||||
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
|
||||
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
m = _load_for_lite_interpreter(buffer) # Error here
|
||||
m.run_method("reset_observers_forward")
|
||||
m.run_method("observe_forward", *inputs)
|
||||
m.run_method("quantize_forward", *inputs)
|
||||
output = m.run_method("quantized_forward", *inputs)
|
||||
self.assertTrue(torch.allclose(ref_output, output))
|
||||
|
||||
def _check_serialization_deserialization(self, model):
|
||||
self._check_serdes_and_device_side_api_helper(model, False)
|
||||
|
||||
|
||||
def _check_device_side_api(self, model):
|
||||
self._check_serdes_and_device_side_api_helper(model, True)
|
||||
|
||||
|
||||
def test_quantize_forward(self):
|
||||
|
|
@ -492,3 +522,8 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
|
|||
def test_serialization_deserialization(self):
|
||||
model = MyConvLinearModule()
|
||||
self._check_serialization_deserialization(model)
|
||||
|
||||
|
||||
def test_device_side_api(self):
|
||||
model = MyConvLinearModule()
|
||||
self._check_device_side_api(model)
|
||||
|
|
|
|||
|
|
@ -298,6 +298,7 @@ def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
|
|||
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
|
||||
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
|
||||
def _export_operator_list(module: LiteScriptModule): ...
|
||||
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
|
||||
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
|
||||
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
|
||||
def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...
|
||||
|
|
|
|||
|
|
@ -43,6 +43,50 @@ Method Module::get_method(const std::string& name) const {
|
|||
AT_ERROR("Method '", name, "' is not defined.");
|
||||
}
|
||||
|
||||
bool Module::compareMethodSchemas(
|
||||
const std::string& name_1,
|
||||
const std::string& name_2) {
|
||||
c10::optional<c10::FunctionSchema> schema_1, schema_2;
|
||||
for (const auto& fn : cu_->methods()) {
|
||||
if (fn->name() == name_1) {
|
||||
schema_1 = fn->getSchema();
|
||||
}
|
||||
if (fn->name() == name_2) {
|
||||
schema_2 = fn->getSchema();
|
||||
}
|
||||
}
|
||||
if (schema_1.has_value() && schema_2.has_value()) {
|
||||
return (schema_1 == schema_2);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void Module::unsafeRemoveMethod(const std::string& basename) {
|
||||
int64_t i = 0;
|
||||
for (; i < cu_->methods().size(); ++i) {
|
||||
if ((cu_->methods()[i])->name() == basename) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
object_->type()->unsafeRemoveMethod(basename);
|
||||
cu_->unsafeRemoveFunction(i);
|
||||
}
|
||||
|
||||
void Module::unsafeCopyMethod(
|
||||
const std::string& new_method_name,
|
||||
const Function& to_be_copied) {
|
||||
TORCH_CHECK(
|
||||
!find_method(new_method_name).has_value(),
|
||||
"Trying to replace existing method.");
|
||||
const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
|
||||
c10::QualifiedName qualified_method_name(
|
||||
tobe_copied_name.prefix(), new_method_name);
|
||||
std::unique_ptr<Function> new_fn = std::make_unique<Function>(
|
||||
qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
|
||||
object_->type()->addMethod(new_fn.get());
|
||||
cu_->register_function(std::move(new_fn));
|
||||
}
|
||||
|
||||
c10::optional<Method> Module::find_method(const std::string& basename) const {
|
||||
for (const auto& fn : cu_->methods()) {
|
||||
if (fn->name() == basename) {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <torch/csrc/jit/mobile/debug_info.h>
|
||||
#include <torch/csrc/jit/mobile/function.h>
|
||||
#include <torch/csrc/jit/mobile/method.h>
|
||||
#include <torch/csrc/jit/mobile/quantization.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -42,6 +43,10 @@ class CompilationUnit {
|
|||
Function* find_function(const c10::QualifiedName& qn);
|
||||
const Function* find_function(const c10::QualifiedName& qn) const;
|
||||
|
||||
// Erases the function stored at position `index` of methods_.
// No bounds checking is performed here; the caller must pass a valid index.
void unsafeRemoveFunction(const int64_t index) {
  const auto position = methods_.begin() + index;
  methods_.erase(position);
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Function>> methods_;
|
||||
};
|
||||
|
|
@ -71,6 +76,7 @@ class TORCH_API Module {
|
|||
return get_method("forward")(std::move(inputs));
|
||||
}
|
||||
c10::optional<Method> find_method(const std::string& basename) const;
|
||||
|
||||
const std::string name() const {
|
||||
return object_->name();
|
||||
}
|
||||
|
|
@ -152,6 +158,18 @@ class TORCH_API Module {
|
|||
}
|
||||
|
||||
private:
|
||||
friend class quantization::PTQQuanizationHelper;
|
||||
|
||||
bool compareMethodSchemas(
|
||||
const std::string& name_1,
|
||||
const std::string& name_2);
|
||||
|
||||
void unsafeRemoveMethod(const std::string& basename);
|
||||
|
||||
void unsafeCopyMethod(
|
||||
const std::string& new_method_name,
|
||||
const Function& to_be_copied);
|
||||
|
||||
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
||||
std::unordered_map<std::string, std::string> metadata_;
|
||||
std::shared_ptr<CompilationUnit> cu_;
|
||||
|
|
|
|||
66
torch/csrc/jit/mobile/quantization.cpp
Normal file
66
torch/csrc/jit/mobile/quantization.cpp
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
#include <ATen/Context.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/quantization.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace mobile {
|
||||
namespace quantization {
|
||||
|
||||
// Converts `method_name` of a PTQ-prepared mobile module to its quantized
// equivalent, mutating `m` in place.
//
// The module must expose reset_observers_<method_name>, observe_<method_name>,
// quantize_<method_name>, quantized_<method_name>, <method_name> itself and
// get_all_bundled_inputs. The first bundled input is used to run the observe
// and quantize steps. Afterwards <method_name> is replaced by a copy of
// quantized_<method_name> and the four helper methods are removed.
//
// Raises (via TORCH_CHECK) when any required method is missing or when no
// bundled input is available.
void PTQQuanizationHelper::quantize_dynamic(
    torch::jit::mobile::Module& m,
    const std::string& method_name) {
  // NOTE(review): global setting changed before validation; presumably keeps
  // the original weights available while ops prepack them - TODO confirm.
  at::globalContext().setReleaseWeightsWhenPrepacking(false);
  const std::string reset_observers_method_name =
      "reset_observers_" + method_name;
  const std::string observe_method_name = "observe_" + method_name;
  const std::string quantize_method_name = "quantize_" + method_name;
  const std::string quantized_method_name = "quantized_" + method_name;

  // Every method touched below must exist; each failure names the method that
  // is actually missing.
  const std::string* const required_methods[] = {
      &method_name,
      &reset_observers_method_name,
      &observe_method_name,
      &quantize_method_name,
      &quantized_method_name};
  for (const std::string* required : required_methods) {
    TORCH_CHECK(
        m.find_method(*required).has_value(),
        "PTQ ready module must have ",
        *required,
        " method.");
  }
  TORCH_CHECK(
      m.find_method("get_all_bundled_inputs").has_value(),
      "PTQ ready module must have get_all_bundled_inputs method.");

  // Use the first bundled input tuple as the arguments for observation and
  // quantization.
  auto bundled_inputs = m.run_method("get_all_bundled_inputs").toList();
  TORCH_CHECK(
      !bundled_inputs.empty(),
      "PTQ ready module must have at least one bundled input.");
  auto inputs = bundled_inputs.get(0).toTupleRef().elements().vec();
  m.get_method(reset_observers_method_name)({});
  m.get_method(observe_method_name)(inputs);
  m.get_method(quantize_method_name)(inputs);

  // NOTE(review): the comparison result is discarded, so a schema mismatch
  // does not stop the replacement - confirm whether it should be enforced.
  m.compareMethodSchemas(method_name, quantized_method_name);
  // Swap the original method for a copy of its quantized counterpart.
  m.unsafeRemoveMethod(method_name);
  const Function& to_be_copied =
      m.find_method(quantized_method_name).value().function();
  m.unsafeCopyMethod(method_name, to_be_copied);
  // The helper methods are no longer needed once the swap is done.
  m.unsafeRemoveMethod(quantized_method_name);
  m.unsafeRemoveMethod(quantize_method_name);
  m.unsafeRemoveMethod(observe_method_name);
  m.unsafeRemoveMethod(reset_observers_method_name);
}
|
||||
} // namespace quantization
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
38
torch/csrc/jit/mobile/quantization.h
Normal file
38
torch/csrc/jit/mobile/quantization.h
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <string>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace mobile {
|
||||
class Module;
|
||||
namespace quantization {
|
||||
/*
 * Device side PTQ API.
 * Once the model has been prepared for quantization on server side, such model
 * is sent to device. On device side the model is further trained. At the end of
 * the training, before the model is readied for inference, we need to quantize
 * the model.
 * Usage of this API is as follows.
 *   PTQQuanizationHelper ptq_helper;
 *   ptq_helper.quantize_dynamic(m, "forward");
 * Args:
 *   m: Captured by reference, an instance of mobile::Module. This module will
 *     be mutated in place to replace its <method_name> method with quantized
 *     equivalent.
 *   method_name: Name of the method to be quantized. AOT preparation for
 *     quantization must also have been done for this method.
 * Returns:
 *   Nothing. `m` is mutated in place; its size should be smaller due to weight
 *   quantization and its <method_name> method should use quantized ops.
 */
class TORCH_API PTQQuanizationHelper {
 public:
  PTQQuanizationHelper() = default;
  // Quantizes method `method_name` of `m` in place; see the comment above the
  // class for the arguments and the expected preparation of `m`.
  void quantize_dynamic(
      torch::jit::mobile::Module& m,
      const std::string& method_name);
};
|
||||
} // namespace quantization
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -16,6 +16,7 @@
|
|||
#include <torch/csrc/jit/mobile/file_format.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/quantization.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
|
||||
#include <torch/csrc/jit/operator_upgraders/utils.h>
|
||||
|
|
@ -1953,6 +1954,12 @@ void initJitScriptBindings(PyObject* module) {
|
|||
m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
|
||||
return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
|
||||
});
|
||||
m.def(
|
||||
"_quantize_ondevice_ptq_dynamic",
|
||||
[](mobile::Module& m, const std::string& method_name) {
|
||||
mobile::quantization::PTQQuanizationHelper ptq_helper;
|
||||
ptq_helper.quantize_dynamic(m, method_name);
|
||||
});
|
||||
|
||||
m.def("_jit_set_emit_hooks", setEmitHooks);
|
||||
m.def("_jit_get_emit_hooks", getEmitHooks);
|
||||
|
|
|
|||
Loading…
Reference in a new issue