pytorch/test/inductor/test_binary_folding.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

351 lines
12 KiB
Python
Raw Permalink Normal View History

# Owner(s): ["module: inductor"]
import functools
import importlib
import itertools
import os
import sys
import torch
from torch import nn
from torch._dynamo.utils import counters
[Inductor][FX]Support efficient conv bn eval (#108757) This PR adds an `efficient_conv_bn_eval_graph_transform` pass to the inductor. It tries to identify consecutive conv + bn **computation** with bn in eval mode, and changes it to a more efficient implementation. It does not modify parameters, which makes it **support training** without any pain. If no such patterns are identified, it does nothing. Therefore, it is backward compatible. It has great benefit in terms of memory footprint: For resnet50 with input batchsize 64, image size 224, forward + backward training: | Technique | Memory Footprint (GB) | Remarks | |-------------------------------|----------------------------|-------------------------------------------| | Eager Mode | 5.18 | | | torch.compile | 5.46 | Strangely, not saving memory | | torch.compile with this PR | 2.88 | **Saves about 50% memory! ** | The script to measure the memory footprint: ```python from torchvision.models.resnet import resnet50 import torch net = resnet50().eval().cuda() input = torch.randn(64, 3, 224, 224).cuda() opt_net = torch.compile(net) # Use torch.compile # opt_net = net # Eager mode current_memory = torch.cuda.memory_allocated() torch.cuda.reset_peak_memory_stats() for i in range(10): opt_net.zero_grad() output = opt_net(input) output.sum().backward() del output peak_memory = torch.cuda.max_memory_allocated() additional_peak_memory = peak_memory - current_memory print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB") ``` More results can be found in the corresponding paper: (this method is called Tune Mode in the tables). 
<img width="709" alt="image" src="https://github.com/pytorch/pytorch/assets/23236638/db4815b0-d93e-4726-b1d5-e6651f256484"> <img width="653" alt="image" src="https://github.com/pytorch/pytorch/assets/23236638/22e5e1ab-6129-4c3d-a875-3c7343293b2e"> Note: the difference between this PR and https://github.com/pytorch/pytorch/pull/106372 is that, https://github.com/pytorch/pytorch/pull/106372 tries to fix and change the implementation of `torch.fx.experimental.optimization.fuse`, which causes compatibility issues; this PR only introduces a new graph transform passes, and does not break the previous code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108757 Approved by: https://github.com/jansel
2023-09-20 08:09:55 +00:00
from torch._inductor import config as inductor_config
from torch.testing._internal.common_cuda import TEST_CUDNN
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from inductor.test_inductor_freezing import ( # @manual=fbcode//caffe2/test/inductor:inductor_freezing-library
TestCase,
Fix lint errors in fbcode (#135614) Summary: Fixed a bunch of fbcode imports that happened to work but confused autodeps. After this autodeps still suggests "improvements" to TARGETS (which breaks our builds) but at least it can find all the imports. Test Plan: ``` fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS ``` Before: ``` ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. 
See ht$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$ ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". 
Please make sure it's listed in the srcs paramete$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include <torch/csrc/autograd/profiler_kineto.h> (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib. Some things to try: ``` Differential Revision: D62049222 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614 Approved by: https://github.com/oulgen, https://github.com/laithsakka
2024-09-13 02:04:34 +00:00
)
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
check_model,
check_model_gpu,
copy_tests,
)
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import skipCUDAIf
importlib.import_module("functorch")
importlib.import_module("filelock")
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
aten = torch.ops.aten
class BinaryFoldingTemplate(TestCase):
@skipCUDAIf(TEST_CUDNN, "CUDNN has accuracy issues for this test")
def test_conv_binary_folding(self):
    """Pointwise add/sub/mul/div after a conv should be folded into the conv
    weights/bias when the other operand is a constant that broadcasts only
    over the channel dimension (or is a python scalar)."""

    @torch.no_grad()
    def run_case(use_bias, module, op, scalar, add_tensor, expect_success):
        # Conv followed by one binary op; `scalar` selects between a python
        # scalar operand and a Parameter operand.
        class ConvOp(nn.Module):
            __constants__ = ["use_scalar"]

            def __init__(self, in_channels, out_channels, device, **kwargs):
                super().__init__()
                self.conv = module(
                    in_channels, out_channels, bias=use_bias, **kwargs
                ).to(device)
                self.use_scalar = scalar
                # Default operand shape: [1, C_out, 1, ...] — broadcasts
                # only over the channel dim, so it is foldable.
                operand_shape = [1] * self.conv.weight.ndim
                operand_shape[1] = self.conv.weight.size(0)
                self.tensor = torch.nn.Parameter(
                    torch.rand(operand_shape).to(device)
                    if add_tensor is None
                    else add_tensor
                )
                self.op = op

            def forward(self, x):
                y = self.conv(x)
                other = 2.0 if self.use_scalar else self.tensor
                return self.op(y, other)

        torch._dynamo.reset()
        counters.clear()
        eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval()
        compiled = torch.compile(eager)
        # Input is (N=4, C=3) plus one spatial dim of size 4 per conv dim.
        spatial_dims = {nn.Conv1d: 1, nn.Conv2d: 2, nn.Conv3d: 3}[module]
        torch.manual_seed(1234)
        inp = torch.rand([4, 3] + [4] * spatial_dims).to(self.device)
        ref = eager(inp)
        res = compiled(inp)
        self.assertEqual(res, ref)
        # Folding is observable through the inductor counter.
        expected_folds = 1 if expect_success else 0
        self.assertEqual(counters["inductor"]["binary_folding"], expected_folds)

    bias_options = [True, False]
    conv_modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
    scalar_options = [True, False]
    binary_ops = [torch.add, torch.sub, torch.mul, torch.div]
    # Channel-shaped constant operands fold for every conv dim / op / bias.
    for use_bias, module, op, scalar in itertools.product(
        bias_options, conv_modules, binary_ops, scalar_options
    ):
        run_case(
            use_bias,
            module,
            op,
            scalar,
            add_tensor=None,
            expect_success=True,
        )
    for use_bias, op in itertools.product(bias_options, binary_ops):
        # Operand broadcasting across a spatial dim: not foldable.
        run_case(
            use_bias,
            nn.Conv2d,
            op,
            False,
            add_tensor=torch.rand(
                32,
                1,
                32,
            ).to(self.device),
            expect_success=False,
        )
        # Scalar-like [1, 1] operand broadcasts everywhere: foldable.
        run_case(
            use_bias,
            nn.Conv2d,
            op,
            False,
            add_tensor=torch.rand(1, 1).to(self.device),
            expect_success=True,
        )
        # Operand of a different dtype (float64): not foldable.
        run_case(
            use_bias,
            nn.Conv2d,
            op,
            False,
            add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
            expect_success=False,
        )
[Inductor][FX]Support efficient conv bn eval (#108757) This PR adds an `efficient_conv_bn_eval_graph_transform` pass to the inductor. It tries to identify consecutive conv + bn **computation** with bn in eval mode, and changes it to a more efficient implementation. It does not modify parameters, which makes it **support training** without any pain. If no such patterns are identified, it does nothing. Therefore, it is backward compatible. It has great benefit in terms of memory footprint: For resnet50 with input batchsize 64, image size 224, forward + backward training: | Technique | Memory Footprint (GB) | Remarks | |-------------------------------|----------------------------|-------------------------------------------| | Eager Mode | 5.18 | | | torch.compile | 5.46 | Strangely, not saving memory | | torch.compile with this PR | 2.88 | **Saves about 50% memory! ** | The script to measure the memory footprint: ```python from torchvision.models.resnet import resnet50 import torch net = resnet50().eval().cuda() input = torch.randn(64, 3, 224, 224).cuda() opt_net = torch.compile(net) # Use torch.compile # opt_net = net # Eager mode current_memory = torch.cuda.memory_allocated() torch.cuda.reset_peak_memory_stats() for i in range(10): opt_net.zero_grad() output = opt_net(input) output.sum().backward() del output peak_memory = torch.cuda.max_memory_allocated() additional_peak_memory = peak_memory - current_memory print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB") ``` More results can be found in the corresponding paper: (this method is called Tune Mode in the tables). 
<img width="709" alt="image" src="https://github.com/pytorch/pytorch/assets/23236638/db4815b0-d93e-4726-b1d5-e6651f256484"> <img width="653" alt="image" src="https://github.com/pytorch/pytorch/assets/23236638/22e5e1ab-6129-4c3d-a875-3c7343293b2e"> Note: the difference between this PR and https://github.com/pytorch/pytorch/pull/106372 is that, https://github.com/pytorch/pytorch/pull/106372 tries to fix and change the implementation of `torch.fx.experimental.optimization.fuse`, which causes compatibility issues; this PR only introduces a new graph transform passes, and does not break the previous code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108757 Approved by: https://github.com/jansel
2023-09-20 08:09:55 +00:00
@inductor_config.patch({"freezing": True})
def test_conv_bn_folding(self):
    """With freezing enabled, a conv followed by an eval-mode batchnorm is
    folded, leaving no aten binary ops in the compiled graph."""

    @torch.no_grad()
    def run_case(use_bias, module, expect_success):
        # `module` is a (conv_cls, bn_cls) pair of matching dimensionality.
        class ConvOp(nn.Module):
            def __init__(self, in_channels, out_channels, device, **kwargs):
                super().__init__()
                self.conv = module[0](
                    in_channels, out_channels, bias=use_bias, **kwargs
                ).to(device)
                self.bn = module[1](out_channels).to(device)

            def forward(self, x):
                return self.bn(self.conv(x))

        from torch._inductor.compile_fx import compile_fx, compile_fx_inner

        aten_binary = [
            aten.add.Tensor,
            aten.sub.Tensor,
            aten.mul.Tensor,
            aten.div.Tensor,
        ]
        n_binary_ops = 0

        # Wrap the inner compile so we can count binary ops left in the
        # graph after inductor's passes have run.
        def counting_inner_compile(gm, example_inputs, *args, **kwargs):
            compiled = compile_fx_inner(gm, example_inputs, *args, **kwargs)
            nonlocal n_binary_ops
            n_binary_ops += sum(
                1 for node in gm.graph.nodes if node.target in aten_binary
            )
            return compiled

        torch._dynamo.reset()
        eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval()
        compiled = torch.compile(
            eager,
            backend=functools.partial(
                compile_fx, inner_compile=counting_inner_compile
            ),
        )
        # Input is (N=4, C=3) plus one spatial dim of size 4 per conv dim.
        spatial_dims = {nn.Conv1d: 1, nn.Conv2d: 2, nn.Conv3d: 3}[module[0]]
        inp = torch.rand([4, 3] + [4] * spatial_dims).to(self.device)
        ref = eager(inp)
        res = compiled(inp)
        # Folding reorders the arithmetic, hence the loose tolerances.
        self.assertEqual(res, ref, atol=2e-04, rtol=1e-5)
        if expect_success:
            self.assertEqual(n_binary_ops, 0)
        else:
            self.assertGreater(n_binary_ops, 1)

    conv_bn_pairs = [
        (nn.Conv1d, nn.BatchNorm1d),
        (nn.Conv2d, nn.BatchNorm2d),
        (nn.Conv3d, nn.BatchNorm3d),
    ]
    for use_bias, module in itertools.product([True, False], conv_bn_pairs):
        run_case(
            use_bias,
            module,
            expect_success=True,
        )
@inductor_config.patch({"enable_linear_binary_folding": True})
def test_linear_binary_folding(self):
    """Binary ops after a linear layer fold into its weights/bias when the
    operand broadcasts only over batch dims (or is a python scalar)."""

    @torch.no_grad()
    def run_case(use_bias, op, scalar, add_tensor, expect_success, input_3d=False):
        class LinearOp(nn.Module):
            __constants__ = ["use_scalar"]

            def __init__(self, in_channels, out_channels, device, **kwargs):
                super().__init__()
                self.linear = nn.Linear(
                    in_channels, out_channels, bias=use_bias, **kwargs
                ).to(device)
                self.use_scalar = scalar
                # Default operand shape: [out_features] — foldable.
                default_shape = [self.linear.weight.size(0)]
                self.tensor = torch.nn.Parameter(
                    torch.rand(default_shape).to(device)
                    if add_tensor is None
                    else add_tensor
                )
                self.op = op

            def forward(self, x):
                y = self.linear(x)
                other = 2.0 if self.use_scalar else self.tensor
                return self.op(y, other)

        torch._dynamo.reset()
        counters.clear()
        eager = LinearOp(3, 32, self.device).eval()
        compiled = torch.compile(eager)
        torch.manual_seed(1234)
        shape = [2, 4, 3] if input_3d else [4, 3]
        inp = torch.rand(shape).to(self.device)
        ref = eager(inp)
        res = compiled(inp)
        self.assertEqual(res, ref, atol=5e-05, rtol=5e-06)
        # Folding is observable through the inductor counter.
        expected_folds = 1 if expect_success else 0
        self.assertEqual(counters["inductor"]["binary_folding"], expected_folds)

    bias_options = [True, False]
    scalar_options = [True, False]
    binary_ops = [torch.add, torch.sub, torch.mul, torch.div]
    # Operand shapes that broadcast only over batch dims: all foldable
    # against the 2D input.
    foldable_shapes = [
        [
            32,
        ],
        [1, 32],
        [
            1,
        ],
        [1, 1],
    ]
    for use_bias, op, scalar, shape in itertools.product(
        bias_options, binary_ops, scalar_options, foldable_shapes
    ):
        run_case(
            use_bias,
            op,
            scalar,
            add_tensor=torch.rand(shape).to(self.device),
            expect_success=True,
        )
    # Same shapes plus leading singleton dims, against a 3D input.
    foldable_shapes.extend([[1, 1, 32], [1, 1, 1]])
    for use_bias, op, scalar, shape in itertools.product(
        bias_options, binary_ops, scalar_options, foldable_shapes
    ):
        run_case(
            use_bias,
            op,
            scalar,
            add_tensor=torch.rand(shape).to(self.device),
            expect_success=True,
            input_3d=True,
        )
    # In the following cases the operand varies along the batch dim, which
    # does not satisfy the binary-folding requirements, so no folding.
    for use_bias, op in itertools.product(bias_options, binary_ops):
        run_case(
            use_bias,
            op,
            False,
            add_tensor=torch.rand(
                4,
                32,
            ).to(self.device),
            expect_success=False,
        )
        run_case(
            use_bias,
            op,
            False,
            add_tensor=torch.rand(
                4,
                1,
            ).to(self.device),
            expect_success=False,
        )
# Instantiate the shared template against the CPU backend.
# NOTE(review): skipped when MPS is available — presumably because these
# tests target the CPU inductor path specifically; confirm rationale.
if HAS_CPU and not torch.backends.mps.is_available():

    class FreezingCpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(BinaryFoldingTemplate, FreezingCpuTests, "cpu")
# Instantiate the shared template against the GPU backend; ASAN builds do
# not run GPU tests.
if HAS_GPU and not TEST_WITH_ASAN:

    class FreezingGpuTests(TestCase):
        common = check_model_gpu
        device = GPU_TYPE
        autocast = torch.amp.autocast(device_type=GPU_TYPE)

    copy_tests(BinaryFoldingTemplate, FreezingGpuTests, GPU_TYPE)


# Remove the template so the test runner does not collect it directly.
del BinaryFoldingTemplate
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Only run when at least one backend is usable; run_tests additionally
    # gates on the filelock dependency.
    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")