[export] Update docs (#142011)

Summary:
Update export docs. Including:
1. Update the output graph.
2. Misc fixes for examples.

Test Plan: CI

Differential Revision: D66726729

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142011
Approved by: https://github.com/angelayi
This commit is contained in:
Yiming Zhou 2024-12-05 03:44:44 +00:00 committed by PyTorch MergeBot
parent 471017cbc9
commit 31f2d4eb4e
3 changed files with 426 additions and 170 deletions

View file

@ -103,15 +103,25 @@ of the Graph of GraphModule.
Example::
from torch import nn
import torch
from torch import nn
class MyModule(nn.Module):
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
def forward(self, x, y):
return x + y
mod = torch.export.export(MyModule())
print(mod.graph)
example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
.. code-block:: python
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
return (add,)
The above is the textual representation of a Graph, with each line being a node.

View file

@ -39,28 +39,41 @@ serialized.
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# code: a = torch.sin(x)
sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)
# code: b = torch.cos(y)
cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
return (add,)
Graph signature: ExportGraphSignature(
parameters=[],
buffers=[],
user_inputs=['arg0_1', 'arg1_1'],
user_outputs=['add'],
inputs_to_parameters={},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {}
``torch.export`` produces a clean intermediate representation (IR) with the
@ -183,39 +196,55 @@ example:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# code: a = self.conv(x)
convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
);
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])
# code: a.add_(constant)
add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)
# code: return self.maxpool(self.relu(a))
relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
relu, [3, 3], [3, 3]
);
getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
return (getitem,)
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
return (max_pool2d,)
Graph signature: ExportGraphSignature(
parameters=['L__self___conv.weight', 'L__self___conv.bias'],
buffers=[],
user_inputs=['arg2_1', 'arg3_1'],
user_outputs=['getitem'],
inputs_to_parameters={
'arg0_1': 'L__self___conv.weight',
'arg1_1': 'L__self___conv.bias',
},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='constant'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='max_pool2d'),
target=None
)
]
)
Range constraints: {}
Range constraints: {}
Inspecting the ``ExportedProgram``, we can note the following:
@ -336,25 +365,69 @@ To show some examples:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
return (batch_norm,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='batch_norm'), target=None)
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='batch_norm'),
target=None
)
]
)
Range constraints: {}
@ -380,36 +453,93 @@ You can also go from this IR to an inference IR via :func:`run_decompositions` w
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
return (getitem_3, getitem_4, add, getitem)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'),
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_4'), target='bn.running_var'),
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'),
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)
]
)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_3'),
target='bn.running_mean'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_4'),
target='bn.running_var'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='bn.num_batches_tracked'
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='getitem'),
target=None
)
]
)
Range constraints: {}
Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR
containing core aten operators except for `conv2d`.
Here you can see that we kept ``conv2d`` op in the IR while decomposing the rest. Now the IR is a functional IR
containing core aten operators except for ``conv2d``.
You can do even more customization by directly registering your chosen decomposition behaviors.
@ -433,32 +563,89 @@ You can do even more customizations by directly registering custom decomp behavi
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
return (getitem_3, getitem_4, add, getitem)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'),
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_4'), target='bn.running_var'),
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'),
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)
]
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_weight'),
target='bn.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_bn_bias'),
target='bn.bias',
persistent=None
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_mean'),
target='bn.running_mean',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_running_var'),
target='bn.running_var',
persistent=True
),
InputSpec(
kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_bn_num_batches_tracked'),
target='bn.num_batches_tracked',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_3'),
target='bn.running_mean'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='getitem_4'),
target='bn.running_var'
),
OutputSpec(
kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='bn.num_batches_tracked'
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='getitem'),
target=None
)
]
)
Range constraints: {}
@ -510,58 +697,94 @@ run. Such dimensions must be specified by using the
.. code-block::
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
# code: out1 = self.branch1(x1)
permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
# code: out1 = self.branch1(x1)
linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)
# code: out2 = self.branch2(x2)
permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
# code: out2 = self.branch2(x2)
linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)
# code: return (out1 + self.buffer, out2)
add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
return (add, relu_1)
# code: return (out1 + self.buffer, out2)
add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
return (add, relu_1)
Graph signature: ExportGraphSignature(
parameters=[
'branch1.0.weight',
'branch1.0.bias',
'branch2.0.weight',
'branch2.0.bias',
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch1_0_weight'),
target='branch1.0.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch1_0_bias'),
target='branch1.0.bias',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch2_0_weight'),
target='branch2.0.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_branch2_0_bias'),
target='branch2.0.bias',
persistent=None
),
InputSpec(
kind=<InputKind.CONSTANT_TENSOR: 4>,
arg=TensorArgument(name='c_buffer'),
target='buffer',
persistent=True
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x1'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x2'),
target=None,
persistent=None
)
],
buffers=['L__self___buffer'],
user_inputs=['arg5_1', 'arg6_1'],
user_outputs=['add', 'relu_1'],
inputs_to_parameters={
'arg0_1': 'branch1.0.weight',
'arg1_1': 'branch1.0.bias',
'arg2_1': 'branch2.0.weight',
'arg3_1': 'branch2.0.bias',
},
inputs_to_buffers={'arg4_1': 'L__self___buffer'},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
),
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='relu_1'),
target=None
)
]
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
Range constraints: {s0: VR[0, int_oo]}
Some additional things to note:
* Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we specified the first
dimension of each input to be dynamic. Looking at the inputs ``arg5_1`` and
``arg6_1``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
dimension of each input to be dynamic. Looking at the inputs ``x1`` and
``x2``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs.
``s0`` is a symbol representing that this dimension can be a range
of values.
* ``exported_program.range_constraints`` describes the ranges of each symbol
appearing in the graph. In this case, we see that ``s0`` has the range
[2, inf]. For technical reasons that are difficult to explain here, they are
[0, int_oo]. For technical reasons that are difficult to explain here, they are
assumed to be not 0 or 1. This is not a bug, and does not necessarily mean
that the exported program will not work for dimensions 0 or 1. See
`The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`_
@ -591,21 +814,37 @@ another, or a shape is even. An example:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
return (add,)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}
Some things to note:
@ -613,8 +852,8 @@ Some things to note:
shape of the first input is now dynamic, being ``[s0]``. And now by specifying
``{0: dimy}`` for the second input, we see that the resulting shape of the
second input is also dynamic. However, because we expressed ``dimy = dimx + 1``,
instead of ``arg1_1``'s shape containing a new symbol, we see that it is
now being represented with the same symbol used in ``arg0_1``, ``s0``. We can
instead of ``y``'s shape containing a new symbol, we see that it is
now being represented with the same symbol used in ``x``, ``s0``. We can
see that relationship of ``dimy = dimx + 1`` is being shown through ``s0 + 1``.
* Looking at the range constraints, we see that ``s0`` has the range [3, 6],
@ -700,10 +939,11 @@ that is being taken with the given sample inputs. For example:
.. code-block::
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 2]):
add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
return (add,)
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 2]"):
# code: return x + 1
add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
return (add,)
The conditional of (``x.shape[0] > 5``) does not appear in the
``ExportedProgram`` because the example inputs have the static
@ -745,19 +985,20 @@ For example:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
def forward(self, x: "f32[2, 2]", const, times):
# code: x = x + const
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
return (add_2,)
Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations
are all computed with the hard-coded constant ``1``, rather than ``arg1_1``. If
a user passes a different value for ``arg1_1`` at runtime, like 2, than the one used
are all computed with the hard-coded constant ``1``, rather than ``const``. If
a user passes a different value for ``const`` at runtime, like 2, than the one used
during export time, 1, this will result in an error.
Additionally, the ``times`` iterator used in the ``for`` loop is also "inlined"
in the graph through the 3 repeated ``torch.ops.aten.add.Tensor`` calls, and the
input ``arg2_1`` is never used.
input ``times`` is never used.
Python Containers
~~~~~~~~~~~~~~~~~

View file

@ -589,22 +589,27 @@ def register_dataclass(
Example::
import torch
from dataclasses import dataclass
@dataclass
class InputDataClass:
feature: torch.Tensor
bias: int
@dataclass
class OutputDataClass:
res: torch.Tensor
torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)
def fn(o: InputDataClass) -> torch.Tensor:
res = res=o.feature + o.bias
return OutputDataClass(res=res)
class Mod(torch.nn.Module):
def forward(self, x: InputDataClass) -> OutputDataClass:
res = x.feature + x.bias
return OutputDataClass(res=res)
ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)
"""