[AOTInductor] Option to not include weight in .so (#141997)

Summary: Add a config option (`aot_inductor.package_constants_in_so`) to exclude the model weights from the generated .so

Test Plan: `test/inductor:test_aot_inductor -- -r test_so_without_weight_cuda`

Reviewed By: desertfire

Differential Revision: D65968885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141997
Approved by: https://github.com/desertfire
Author: Mu-Chu Lee
Date: 2024-12-05 03:35:54 +00:00
Committer: PyTorch MergeBot
Parent: 51cbac4e6a
Commit: b08bc07cd7

6 changed files with 105 additions and 10 deletions
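Before the file-by-file diff, here is the end-to-end idea in miniature, modeled on the test added below. This is a sketch, not part of the commit: `AOTIRunnerUtil` is the test-only helper from `test/inductor`, and the model is a stand-in.

```python
import os

import torch
from torch._inductor import config

# AOTIRunnerUtil is the test-only helper used by test/inductor (assumed
# importable here); it wraps AOT compilation and runner loading.
from test_aot_inductor_utils import AOTIRunnerUtil


class Linear(torch.nn.Module):
    def __init__(self, n, k, device):
        super().__init__()
        self.weight = torch.randn(n, k, device=device)
        self.bias = torch.randn(n, device=device)

    def forward(self, a):
        return torch.nn.functional.linear(a, self.weight, self.bias)


model = Linear(2048, 4096, "cuda")
example_inputs = (torch.randn(128, 4096, device="cuda"),)

# Compile without packaging the constants: the new flag is the only change
# relative to a normal AOTInductor compile.
with torch.no_grad(), config.patch(
    {
        "always_keep_tensor_constants": True,
        "aot_inductor.package_constants_in_so": False,
    }
):
    so_path_weightless = AOTIRunnerUtil.compile(
        model=model, example_inputs=example_inputs
    )

# The weights are not embedded, so the .so stays small; the test below
# asserts < 10 MB here versus > 10 MB with packaging enabled.
print(os.path.getsize(so_path_weightless))
```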


@@ -3866,6 +3866,76 @@ class AOTInductorTestsTemplate:
         example_inputs = (torch.randn(2, 128, 4096, device=self.device),)
         self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}})

+    def test_so_without_weight(self):
+        class Model(torch.nn.Module):
+            def __init__(self, n, k, device):
+                super().__init__()
+                self.weight = torch.randn(n, k, device=device)
+                self.bias = torch.randn(n, device=device)
+
+            def forward(self, a):
+                return torch.nn.functional.linear(a, self.weight, self.bias)
+
+        M, N, K = 128, 2048, 4096
+        model = Model(N, K, self.device)
+        a = torch.randn(M, K, device=self.device)
+        example_inputs = (a,)
+
+        with torch.no_grad(), config.patch(
+            {
+                "always_keep_tensor_constants": True,
+                "aot_inductor.package_constants_in_so": True,
+            }
+        ):
+            so_path = AOTIRunnerUtil.compile(
+                model=model,
+                example_inputs=example_inputs,
+            )
+
+        with torch.no_grad(), config.patch(
+            {
+                "always_keep_tensor_constants": True,
+                "aot_inductor.package_constants_in_so": False,
+            }
+        ):
+            so_path_weightless = AOTIRunnerUtil.compile(
+                model=model,
+                example_inputs=example_inputs,
+            )
+        self.assertTrue(os.path.getsize(so_path) > 10_000_000)
+        self.assertTrue(os.path.getsize(so_path_weightless) < 10_000_000)
+        runner = AOTIRunnerUtil.load_runner(self.device, so_path_weightless)
+
+        # Check that the model has the correct constant-name mapping.
+        expected_original_fqns = {
+            "L__self___weight": "L__self___weight",
+            "L__self___bias": "L__self___bias",
+        }
+        self.assertEqual(
+            expected_original_fqns, runner.get_constant_names_to_original_fqns()
+        )
+
+        def runner_call(*args, **kwargs):
+            import torch.fx._pytree as fx_pytree
+
+            call_spec = runner.get_call_spec()
+            in_spec = pytree.treespec_loads(call_spec[0])
+            out_spec = pytree.treespec_loads(call_spec[1])
+            flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
+            flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
+            flat_outputs = runner.run(flat_inputs)
+            return pytree.tree_unflatten(flat_outputs, out_spec)
+
+        test_inputs = torch.randn(M, K, device=self.device)
+        attach_weights = {
+            "L__self___weight": model.weight,
+            "L__self___bias": model.bias,
+        }
+        runner.update_constant_buffer(attach_weights, False, False)
+        expected = model(test_inputs)
+        output = runner_call(test_inputs)
+        self.assertEqual(expected, output)
+
     def test_update_constant_buffer(self):
         class Model(torch.nn.Module):
             def __init__(self, n, k, device):
@@ -3987,6 +4057,7 @@ CPU_TEST_FAILURES = {
     # TODO: failed internally
     "test_multiple_output_alias": fail_cpu(is_skip=True),
     "test_update_constant_buffer": fail_cpu(is_skip=True),
+    "test_so_without_weight": fail_cpu(is_skip=True),
 }

 # test_failures, xfail by default, set is_skip=True to skip


@@ -169,6 +169,7 @@ CPU_TEST_FAILURES = {
     "test_symbool_item": fail_minimal_arrayref_interface(is_skip=True),
     "test_issue_140766": fail_minimal_arrayref_interface(),
     "test_update_constant_buffer": fail_stack_allocation(is_skip=True),
+    "test_so_without_weight": fail_stack_allocation(is_skip=True),
 }


@@ -1637,11 +1637,15 @@ class AotCodeCompiler:
             raw_bytes = bytes(raw_array.contents)
             return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)

-        serialized_weights = b"".join(
-            _to_bytes(graph.get_original_value_of_constant(name), all_cuda)
-            for name in graph.constants.keys()
-            if name not in graph.folded_constants
-        )
+        if config.aot_inductor.package_constants_in_so:
+            serialized_weights = b"".join(
+                _to_bytes(graph.get_original_value_of_constant(name), all_cuda)
+                for name in graph.constants.keys()
+                if name not in graph.folded_constants
+            )
+        else:
+            serialized_weights = b""
+
         consts_size = len(serialized_weights)

         # TODO: Fix mmap weights with cuda
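For intuition about what is (or is no longer) being embedded: the consts section is just the constants' raw bytes concatenated, padded for alignment when not all on CUDA. A simplified stand-in for `_to_bytes`/`_pad_to_alignment`, not the actual implementation:

```python
import torch

ALIGNMENT = 64  # illustrative; the real alignment constant lives in codecache


def pad_to_alignment(raw: bytes, align: int = ALIGNMENT) -> bytes:
    # Zero-pad so the next constant starts at an aligned offset.
    return raw + b"\x00" * ((-len(raw)) % align)


def to_bytes(t: torch.Tensor) -> bytes:
    # Simplified: a contiguous CPU copy of the tensor viewed as raw bytes.
    return t.detach().contiguous().cpu().numpy().tobytes()


weights = {"weight": torch.randn(2048, 4096), "bias": torch.randn(2048)}
blob = b"".join(pad_to_alignment(to_bytes(t)) for t in weights.values())
# With package_constants_in_so=False the blob is simply b"", which is why
# the weightless .so in the test is an order of magnitude smaller.
print(len(blob))
```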
@@ -1711,6 +1715,7 @@ class AotCodeCompiler:
                 aot_mode=graph.aot_mode,
                 use_absolute_path=use_absolute_path,
             )
             so_builder = CppBuilder(
                 name=output_name,
                 sources=[output_o, consts_o, kernels_o],


@@ -564,13 +564,17 @@ class CppWrapperCpu(PythonWrapperCodegen):
         num_inputs = len(V.graph.graph_inputs)
         num_outputs = len(V.graph.graph_outputs)
         num_constants = len(V.graph.constants)
+        include_weights = (
+            "true" if config.aot_inductor.package_constants_in_so else "false"
+        )
         self.prefix.splice(
             f"""
             AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
                                                std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                                const std::string& device_str,
-                                               std::optional<std::string> cubin_dir)
-                : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
+                                               std::optional<std::string> cubin_dir,
+                                               bool include_weights)
+                : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir, {include_weights}) {{
             """
         )


@@ -1118,6 +1118,9 @@ class aot_inductor:
     # but performance for that interface may be degraded.
     use_minimal_arrayref_interface: bool = False

+    # Experimental. Flag to control whether to include weights in the .so
+    package_constants_in_so: bool = True
+

 class cuda:
     # CUDA arch to use for CUDA template kernel compilation.
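For callers that do not want a scoped `config.patch`, the flag can also be flipped directly on the config module (a sketch of ordinary Inductor config usage):

```python
import torch._inductor.config as inductor_config

# Disable weight packaging globally before AOT compilation; equivalent to
# the config.patch({"aot_inductor.package_constants_in_so": False}) used
# in the test above.
inductor_config.aot_inductor.package_constants_in_so = False
```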


@@ -107,11 +107,13 @@ class AOTInductorModelBase {
       size_t num_outputs,
       size_t num_constants,
       const std::string& device_str,
-      std::optional<std::string> cubin_dir)
+      std::optional<std::string> cubin_dir,
+      bool include_weights = true)
       : inputs_info_(num_inputs),
         outputs_info_(num_outputs),
         constants_info_(num_constants),
-        cubin_dir_(std::move(cubin_dir)) {
+        cubin_dir_(std::move(cubin_dir)),
+        include_weights(include_weights) {
     parse_device_str(device_str, device_type_, device_idx_);

 #ifdef USE_CUDA
@@ -209,6 +211,9 @@ class AOTInductorModelBase {
       constant_blob_ = RAII_cudaMalloc(blob_size);
 #endif
     }
+    if (!include_weights) {
+      return;
+    }

     size_t bytes_read = 0;
     for (size_t i = 0; i < num_constants; i++) {
@@ -568,6 +573,11 @@ class AOTInductorModelBase {
   // A directory with CUDA binary files, e.g. compiled kernels, etc.
   const std::optional<std::string> cubin_dir_;

+  // This flag indicates whether the weights are included in the model.
+  // If true, the weights are prepared when the model is loaded; otherwise
+  // the model is loaded weightless and the user must provide the weights.
+  bool include_weights;
+
   // Record if the model finishes an inference run so that its owning
   // AOTModelContainer can re-use this instance.
 #ifdef USE_CUDA
@@ -593,7 +603,8 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
       std::shared_ptr<ConstantMap> constants_map,
       std::shared_ptr<std::vector<ConstantHandle>> constants_array,
       const std::string& device_str,
-      std::optional<std::string> cubin_dir);
+      std::optional<std::string> cubin_dir,
+      bool include_weights = true);

   std::unordered_map<std::string, AtenTensorHandle> const_run_impl(
       DeviceStreamType stream,
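Tying the runtime pieces together: with `include_weights = false`, the constant-loading path above returns early, so the caller must supply real tensors before running. A sketch continuing the compile example near the top of this page; the two boolean arguments to `update_constant_buffer` simply mirror the test's `update_constant_buffer(weights, False, False)` call, and their exact semantics in the runner API are assumed here:

```python
# Continues the compile sketch above: model, so_path_weightless, and
# AOTIRunnerUtil are as defined there.
runner = AOTIRunnerUtil.load_runner("cuda", so_path_weightless)

# The .so still records its constants' original FQNs, so a state_dict-style
# mapping can be rebuilt even though no weight data was packaged.
print(runner.get_constant_names_to_original_fqns())

# Attach the real tensors under the .so's internal constant names.
runner.update_constant_buffer(
    {"L__self___weight": model.weight, "L__self___bias": model.bias},
    False,
    False,
)

x = torch.randn(128, 4096, device="cuda")
(out,) = runner.run([x])  # run() takes and returns flat lists of tensors
torch.testing.assert_close(out, model(x))
```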