mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[AOTInductor] Option to not include weight in .so (#141997)
Summary: Add an option in config to not include the weights in the .so.

Test Plan: `test/inductor:test_aot_inductor -- -r test_so_without_weight_cuda`

Reviewed By: desertfire

Differential Revision: D65968885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141997
Approved by: https://github.com/desertfire
This commit is contained in:
parent
51cbac4e6a
commit
b08bc07cd7
6 changed files with 105 additions and 10 deletions
|
|
@ -3866,6 +3866,76 @@ class AOTInductorTestsTemplate:
|
|||
example_inputs = (torch.randn(2, 128, 4096, device=self.device),)
|
||||
self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}})
|
||||
|
||||
def test_so_without_weight(self):
    """Compile the same linear model twice -- once with constants packaged
    into the .so and once without -- then verify that the weightless .so is
    much smaller, still exposes the constant-name mapping, and produces
    correct results once the weights are attached at runtime.
    """

    class Model(torch.nn.Module):
        def __init__(self, n, k, device):
            super().__init__()
            self.weight = torch.randn(n, k, device=device)
            self.bias = torch.randn(n, device=device)

        def forward(self, a):
            return torch.nn.functional.linear(a, self.weight, self.bias)

    M, N, K = 128, 2048, 4096
    model = Model(N, K, self.device)
    example_inputs = (torch.randn(M, K, device=self.device),)

    # Baseline build: constants are embedded in the shared library.
    with torch.no_grad(), config.patch(
        {
            "always_keep_tensor_constants": True,
            "aot_inductor.package_constants_in_so": True,
        }
    ):
        so_path = AOTIRunnerUtil.compile(
            model=model,
            example_inputs=example_inputs,
        )

    # Weightless build: identical model, but constants are left out.
    with torch.no_grad(), config.patch(
        {
            "always_keep_tensor_constants": True,
            "aot_inductor.package_constants_in_so": False,
        }
    ):
        so_path_weightless = AOTIRunnerUtil.compile(
            model=model,
            example_inputs=example_inputs,
        )

    # The 2048x4096 weight alone is far larger than 10MB, so the size
    # threshold cleanly separates the two artifacts.
    self.assertTrue(os.path.getsize(so_path) > 10_000_000)
    self.assertTrue(os.path.getsize(so_path_weightless) < 10_000_000)

    runner = AOTIRunnerUtil.load_runner(self.device, so_path_weightless)

    # Even without packaged weights, the model must report the correct
    # constant-name -> original-FQN mapping so the user can attach them.
    expected_original_fqns = {
        "L__self___weight": "L__self___weight",
        "L__self___bias": "L__self___bias",
    }
    self.assertEqual(
        expected_original_fqns, runner.get_constant_names_to_original_fqns()
    )

    def runner_call(*args, **kwargs):
        # Flatten the call per the runner's pytree specs, run, and
        # unflatten the outputs back into the model's return structure.
        import torch.fx._pytree as fx_pytree

        call_spec = runner.get_call_spec()
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
        flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
        flat_outputs = runner.run(flat_inputs)
        return pytree.tree_unflatten(flat_outputs, out_spec)

    test_inputs = torch.randn(M, K, device=self.device)

    # Attach the real weights to the weightless runner, then check that
    # its output matches eager execution of the original model.
    attach_weights = {
        "L__self___weight": model.weight,
        "L__self___bias": model.bias,
    }
    runner.update_constant_buffer(attach_weights, False, False)
    expected = model(test_inputs)
    output = runner_call(test_inputs)
    self.assertEqual(expected, output)
|
||||
|
||||
def test_update_constant_buffer(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, n, k, device):
|
||||
|
|
@ -3987,6 +4057,7 @@ CPU_TEST_FAILURES = {
|
|||
# TODO: failed internally
|
||||
"test_multiple_output_alias": fail_cpu(is_skip=True),
|
||||
"test_update_constant_buffer": fail_cpu(is_skip=True),
|
||||
"test_so_without_weight": fail_cpu(is_skip=True),
|
||||
}
|
||||
|
||||
# test_failures, xfail by default, set is_skip=True to skip
|
||||
|
|
|
|||
|
|
@ -169,6 +169,7 @@ CPU_TEST_FAILURES = {
|
|||
"test_symbool_item": fail_minimal_arrayref_interface(is_skip=True),
|
||||
"test_issue_140766": fail_minimal_arrayref_interface(),
|
||||
"test_update_constant_buffer": fail_stack_allocation(is_skip=True),
|
||||
"test_so_without_weight": fail_stack_allocation(is_skip=True),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1637,11 +1637,15 @@ class AotCodeCompiler:
|
|||
raw_bytes = bytes(raw_array.contents)
|
||||
return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)
|
||||
|
||||
serialized_weights = b"".join(
|
||||
_to_bytes(graph.get_original_value_of_constant(name), all_cuda)
|
||||
for name in graph.constants.keys()
|
||||
if name not in graph.folded_constants
|
||||
)
|
||||
if config.aot_inductor.package_constants_in_so:
|
||||
serialized_weights = b"".join(
|
||||
_to_bytes(graph.get_original_value_of_constant(name), all_cuda)
|
||||
for name in graph.constants.keys()
|
||||
if name not in graph.folded_constants
|
||||
)
|
||||
else:
|
||||
serialized_weights = b""
|
||||
|
||||
consts_size = len(serialized_weights)
|
||||
|
||||
# TODO: Fix mmap weights with cuda
|
||||
|
|
@ -1711,6 +1715,7 @@ class AotCodeCompiler:
|
|||
aot_mode=graph.aot_mode,
|
||||
use_absolute_path=use_absolute_path,
|
||||
)
|
||||
|
||||
so_builder = CppBuilder(
|
||||
name=output_name,
|
||||
sources=[output_o, consts_o, kernels_o],
|
||||
|
|
|
|||
|
|
@ -564,13 +564,17 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
num_inputs = len(V.graph.graph_inputs)
|
||||
num_outputs = len(V.graph.graph_outputs)
|
||||
num_constants = len(V.graph.constants)
|
||||
include_weights = (
|
||||
"true" if config.aot_inductor.package_constants_in_so else "false"
|
||||
)
|
||||
self.prefix.splice(
|
||||
f"""
|
||||
AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
|
||||
std::shared_ptr<std::vector<ConstantHandle>> constants_array,
|
||||
const std::string& device_str,
|
||||
std::optional<std::string> cubin_dir)
|
||||
: AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
|
||||
std::optional<std::string> cubin_dir,
|
||||
bool include_weights)
|
||||
: AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir, {include_weights}) {{
|
||||
"""
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1118,6 +1118,9 @@ class aot_inductor:
|
|||
# but performance for that interface may be degraded.
|
||||
use_minimal_arrayref_interface: bool = False
|
||||
|
||||
# Experimental. Flag to control whether to include weight in .so
|
||||
package_constants_in_so: bool = True
|
||||
|
||||
|
||||
class cuda:
|
||||
# CUDA arch to use for CUDA template kernel compilation.
|
||||
|
|
|
|||
|
|
@ -107,11 +107,13 @@ class AOTInductorModelBase {
|
|||
size_t num_outputs,
|
||||
size_t num_constants,
|
||||
const std::string& device_str,
|
||||
std::optional<std::string> cubin_dir)
|
||||
std::optional<std::string> cubin_dir,
|
||||
bool include_weights = true)
|
||||
: inputs_info_(num_inputs),
|
||||
outputs_info_(num_outputs),
|
||||
constants_info_(num_constants),
|
||||
cubin_dir_(std::move(cubin_dir)) {
|
||||
cubin_dir_(std::move(cubin_dir)),
|
||||
include_weights(include_weights) {
|
||||
parse_device_str(device_str, device_type_, device_idx_);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -209,6 +211,9 @@ class AOTInductorModelBase {
|
|||
constant_blob_ = RAII_cudaMalloc(blob_size);
|
||||
#endif
|
||||
}
|
||||
if (!include_weights) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t bytes_read = 0;
|
||||
for (size_t i = 0; i < num_constants; i++) {
|
||||
|
|
@ -568,6 +573,11 @@ class AOTInductorModelBase {
|
|||
// A directory with CUDA binary files, e.g. compiled kernels, etc.
|
||||
const std::optional<std::string> cubin_dir_;
|
||||
|
||||
// This is the flag that implies whether the weight is included in the model.
|
||||
// If True, we would prepare the weight when loading the model, otherwise the
|
||||
// model will be loaded without weights, and need to be provided by the user.
|
||||
bool include_weights;
|
||||
|
||||
// Record if the model finishes an inference run so that its owning
|
||||
// AOTModelContainer can re-use this instance.
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -593,7 +603,8 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
|
|||
std::shared_ptr<ConstantMap> constants_map,
|
||||
std::shared_ptr<std::vector<ConstantHandle>> constants_array,
|
||||
const std::string& device_str,
|
||||
std::optional<std::string> cubin_dir);
|
||||
std::optional<std::string> cubin_dir,
|
||||
bool include_weights = true);
|
||||
|
||||
std::unordered_map<std::string, AtenTensorHandle> const_run_impl(
|
||||
DeviceStreamType stream,
|
||||
|
|
|
|||
Loading…
Reference in a new issue