[export] Update docs (#142011)
Summary: Update export docs, including:

1. Update the output graph.
2. Misc fixes for examples.

Test Plan: CI

Differential Revision: D66726729

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142011
Approved by: https://github.com/angelayi
parent 471017cbc9
commit 31f2d4eb4e

3 changed files with 426 additions and 170 deletions
@@ -103,6 +103,7 @@ of the Graph of GraphModule.
 Example::
 
     import torch
+    from torch import nn
 
     class MyModule(nn.Module):
@@ -110,9 +111,18 @@ Example::
         def forward(self, x, y):
             return x + y
 
-    mod = torch.export.export(MyModule())
+    example_args = (torch.randn(1), torch.randn(1))
+    mod = torch.export.export(MyModule(), example_args)
     print(mod.graph)
 
+.. code-block:: python
+
+    graph():
+        %x : [num_users=1] = placeholder[target=x]
+        %y : [num_users=1] = placeholder[target=y]
+        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
+        return (add,)
+
 The above is the textual representation of a Graph, with each line being a node.
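The updated example above is self-contained and can be run as-is; the node-walking loop at the end is an illustrative extra using standard ``torch.fx`` graph APIs, not part of the documented example:

.. code-block:: python

    import torch
    from torch import nn

    class MyModule(nn.Module):
        def forward(self, x, y):
            return x + y

    example_args = (torch.randn(1), torch.randn(1))
    mod = torch.export.export(MyModule(), example_args)
    print(mod.graph)

    # Each printed line of the graph corresponds to one fx.Node, which can
    # also be inspected programmatically.
    for node in mod.graph.nodes:
        print(node.op, node.target)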
 
 Node
@@ -39,27 +39,40 @@ serialized.
     ExportedProgram:
         class GraphModule(torch.nn.Module):
-            def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
+            def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
                 # code: a = torch.sin(x)
-                sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
+                sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)
 
                 # code: b = torch.cos(y)
-                cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
+                cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)
 
                 # code: return a + b
-                add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
+                add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos)
                 return (add,)
 
-    Graph signature: ExportGraphSignature(
-        parameters=[],
-        buffers=[],
-        user_inputs=['arg0_1', 'arg1_1'],
-        user_outputs=['add'],
-        inputs_to_parameters={},
-        inputs_to_buffers={},
-        buffers_to_mutate={},
-        backward_signature=None,
-        assertion_dep_token=None,
-    )
+    Graph signature:
+        ExportGraphSignature(
+            input_specs=[
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x'),
+                    target=None,
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='y'),
+                    target=None,
+                    persistent=None
+                )
+            ],
+            output_specs=[
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='add'),
+                    target=None
+                )
+            ]
+        )
     Range constraints: {}
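For reference, output in this shape plausibly comes from a module reconstructed from the ``# code:`` comments above (the class name ``M`` is illustrative, not taken from the docs source):

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x, y):
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    ep = torch.export.export(M(), (torch.randn(10, 10), torch.randn(10, 10)))
    print(ep)  # prints an ExportedProgram like the one above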
@@ -183,37 +196,53 @@ example:
     ExportedProgram:
         class GraphModule(torch.nn.Module):
-            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
+            def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
                 # code: a = self.conv(x)
-                convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
-                    arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
-                );
+                conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])
 
                 # code: a.add_(constant)
-                add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
+                add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)
 
                 # code: return self.maxpool(self.relu(a))
-                relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
-                max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
-                    relu, [3, 3], [3, 3]
-                );
-                getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
-                return (getitem,)
+                relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
+                max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
+                return (max_pool2d,)
 
-    Graph signature: ExportGraphSignature(
-        parameters=['L__self___conv.weight', 'L__self___conv.bias'],
-        buffers=[],
-        user_inputs=['arg2_1', 'arg3_1'],
-        user_outputs=['getitem'],
-        inputs_to_parameters={
-            'arg0_1': 'L__self___conv.weight',
-            'arg1_1': 'L__self___conv.bias',
-        },
-        inputs_to_buffers={},
-        buffers_to_mutate={},
-        backward_signature=None,
-        assertion_dep_token=None,
-    )
+    Graph signature:
+        ExportGraphSignature(
+            input_specs=[
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_weight'),
+                    target='conv.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_bias'),
+                    target='conv.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x'),
+                    target=None,
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='constant'),
+                    target=None,
+                    persistent=None
+                )
+            ],
+            output_specs=[
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='max_pool2d'),
+                    target=None
+                )
+            ]
+        )
     Range constraints: {}
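Again reconstructing from the ``# code:`` comments, a module of roughly this shape would produce the program above. The layer hyperparameters are inferred from the printed shapes and are assumptions:

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Conv2d(3, 16, 3, padding=1) keeps the 256x256 spatial size,
            # matching the f32[1, 16, 256, 256] activations above.
            self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(3)  # kernel 3, stride 3 -> 256 // 3 == 85

        def forward(self, x, constant):
            a = self.conv(x)
            a.add_(constant)
            return self.maxpool(self.relu(a))

    example_args = (torch.randn(1, 3, 256, 256), torch.randn(1, 16, 256, 256))
    ep = torch.export.export(M(), example_args)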
@@ -336,25 +365,69 @@ To show some examples:
     ExportedProgram:
         class GraphModule(torch.nn.Module):
             def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
-                conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
-                add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None
-                batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
+                conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
+                add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
+                batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
                 return (batch_norm,)
 
     Graph signature:
         ExportGraphSignature(
             input_specs=[
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True),
-                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_weight'),
+                    target='conv.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_bias'),
+                    target='conv.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_bn_weight'),
+                    target='bn.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_bn_bias'),
+                    target='bn.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_running_mean'),
+                    target='bn.running_mean',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_running_var'),
+                    target='bn.running_var',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_num_batches_tracked'),
+                    target='bn.num_batches_tracked',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x'),
+                    target=None,
+                    persistent=None
+                )
             ],
             output_specs=[
-                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='batch_norm'), target=None)
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='batch_norm'),
+                    target=None
+                )
            ]
        )
     Range constraints: {}
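A conv-batchnorm module matching the printed shapes would look roughly like the following. ``export_for_training`` (available in recent PyTorch releases) is the entry point that keeps torch-level ops such as ``batch_norm`` and the in-place ``add_`` un-decomposed, as in the output above:

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 1)  # weight: f32[3, 1, 1, 1]
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            return self.bn(self.conv(x))

    ep = torch.export.export_for_training(M(), (torch.randn(1, 1, 3, 3),))
    print(ep)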
@@ -380,36 +453,93 @@ You can also go from this IR to an inference IR via :func:`run_decompositions` w
     ExportedProgram:
         class GraphModule(torch.nn.Module):
             def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
-                conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
-                add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
-                _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
+                conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
+                add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
+                _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
                 getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
                 getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
-                getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+                getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
                 return (getitem_3, getitem_4, add, getitem)
 
-    Graph signature: ExportGraphSignature(
+    Graph signature:
+        ExportGraphSignature(
             input_specs=[
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True),
-                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_weight'),
+                    target='conv.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_bias'),
+                    target='conv.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_bn_weight'),
+                    target='bn.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_bn_bias'),
+                    target='bn.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_running_mean'),
+                    target='bn.running_mean',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_running_var'),
+                    target='bn.running_var',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_num_batches_tracked'),
+                    target='bn.num_batches_tracked',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x'),
+                    target=None,
+                    persistent=None
+                )
             ],
             output_specs=[
-                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'),
-                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_4'), target='bn.running_var'),
-                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'),
-                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)
+                OutputSpec(
+                    kind=<OutputKind.BUFFER_MUTATION: 3>,
+                    arg=TensorArgument(name='getitem_3'),
+                    target='bn.running_mean'
+                ),
+                OutputSpec(
+                    kind=<OutputKind.BUFFER_MUTATION: 3>,
+                    arg=TensorArgument(name='getitem_4'),
+                    target='bn.running_var'
+                ),
+                OutputSpec(
+                    kind=<OutputKind.BUFFER_MUTATION: 3>,
+                    arg=TensorArgument(name='add'),
+                    target='bn.num_batches_tracked'
+                ),
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='getitem'),
+                    target=None
+                )
            ]
        )
     Range constraints: {}
 
-Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR
-containing core aten operators except for `conv2d`.
+Here you can see that we kept ``conv2d`` op in the IR while decomposing the rest. Now the IR is a functional IR
+containing core aten operators except for ``conv2d``.
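A sketch of how ``conv2d`` can be kept while everything else is decomposed, assuming a release that provides ``torch.export.default_decompositions`` (``ep`` is the training-IR program from the previous example):

.. code-block:: python

    # Start from the default decomposition table, then drop the entry for
    # conv2d so run_decompositions leaves it intact.
    decomp_table = torch.export.default_decompositions()
    del decomp_table[torch.ops.aten.conv2d.default]

    ep = ep.run_decompositions(decomp_table=decomp_table)
    print(ep)  # functional core-aten IR, with conv2d preserved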
 
+You can do even more customization by directly registering your chosen decomposition behaviors.
@@ -433,31 +563,88 @@ You can do even more customizations by directly registering custom decomp behavi
     ExportedProgram:
         class GraphModule(torch.nn.Module):
             def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
-                convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
-                mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None
-                add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
-                _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
+                convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
+                mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
+                add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
+                _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
                 getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
                 getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
-                getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
+                getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
                 return (getitem_3, getitem_4, add, getitem)
 
-    Graph signature: ExportGraphSignature(
+    Graph signature:
+        ExportGraphSignature(
             input_specs=[
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None),
-                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True),
-                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True),
-                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_weight'),
+                    target='conv.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_conv_bias'),
+                    target='conv.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_bn_weight'),
+                    target='bn.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_bn_bias'),
+                    target='bn.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_running_mean'),
+                    target='bn.running_mean',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_running_var'),
+                    target='bn.running_var',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.BUFFER: 3>,
+                    arg=TensorArgument(name='b_bn_num_batches_tracked'),
+                    target='bn.num_batches_tracked',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x'),
+                    target=None,
+                    persistent=None
+                )
             ],
             output_specs=[
-                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'),
-                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_4'), target='bn.running_var'),
-                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'),
-                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)
+                OutputSpec(
+                    kind=<OutputKind.BUFFER_MUTATION: 3>,
+                    arg=TensorArgument(name='getitem_3'),
+                    target='bn.running_mean'
+                ),
+                OutputSpec(
+                    kind=<OutputKind.BUFFER_MUTATION: 3>,
+                    arg=TensorArgument(name='getitem_4'),
+                    target='bn.running_var'
+                ),
+                OutputSpec(
+                    kind=<OutputKind.BUFFER_MUTATION: 3>,
+                    arg=TensorArgument(name='add'),
+                    target='bn.num_batches_tracked'
+                ),
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='getitem'),
+                    target=None
+                )
            ]
        )
     Range constraints: {}
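One way a graph like the above (``convolution`` followed by a ``mul`` by 2) can arise is by mapping ``conv2d`` to a custom decomposition in the table passed to ``run_decompositions``. A hypothetical sketch; the helper ``conv2d_times_two`` is invented for illustration and is not part of the documented API:

.. code-block:: python

    import torch

    def conv2d_times_two(x, weight, bias=None, stride=(1, 1), padding=(0, 0),
                         dilation=(1, 1), groups=1):
        # Run the underlying convolution, then double the result.
        out = torch.ops.aten.convolution.default(
            x, weight, bias, list(stride), list(padding), list(dilation),
            False, [0, 0], groups,
        )
        return out * 2

    decomp_table = torch.export.default_decompositions()
    decomp_table[torch.ops.aten.conv2d.default] = conv2d_times_two
    ep = ep.run_decompositions(decomp_table=decomp_table)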
@@ -511,57 +698,93 @@ run. Such dimensions must be specified by using the
     ExportedProgram:
         class GraphModule(torch.nn.Module):
-            def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
+            def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
                 # code: out1 = self.branch1(x1)
-                permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
-                addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
-                relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
+                linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
+                relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)
 
                 # code: out2 = self.branch2(x2)
-                permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
-                addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
-                relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
+                linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
+                relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)
 
                 # code: return (out1 + self.buffer, out2)
-                add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
+                add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
                 return (add, relu_1)
 
-    Graph signature: ExportGraphSignature(
-        parameters=[
-            'branch1.0.weight',
-            'branch1.0.bias',
-            'branch2.0.weight',
-            'branch2.0.bias',
-        ],
-        buffers=['L__self___buffer'],
-        user_inputs=['arg5_1', 'arg6_1'],
-        user_outputs=['add', 'relu_1'],
-        inputs_to_parameters={
-            'arg0_1': 'branch1.0.weight',
-            'arg1_1': 'branch1.0.bias',
-            'arg2_1': 'branch2.0.weight',
-            'arg3_1': 'branch2.0.bias',
-        },
-        inputs_to_buffers={'arg4_1': 'L__self___buffer'},
-        buffers_to_mutate={},
-        backward_signature=None,
-        assertion_dep_token=None,
-    )
-    Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
+    Graph signature:
+        ExportGraphSignature(
+            input_specs=[
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_branch1_0_weight'),
+                    target='branch1.0.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_branch1_0_bias'),
+                    target='branch1.0.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_branch2_0_weight'),
+                    target='branch2.0.weight',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.PARAMETER: 2>,
+                    arg=TensorArgument(name='p_branch2_0_bias'),
+                    target='branch2.0.bias',
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.CONSTANT_TENSOR: 4>,
+                    arg=TensorArgument(name='c_buffer'),
+                    target='buffer',
+                    persistent=True
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x1'),
+                    target=None,
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x2'),
+                    target=None,
+                    persistent=None
+                )
+            ],
+            output_specs=[
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='add'),
+                    target=None
+                ),
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='relu_1'),
+                    target=None
+                )
+            ]
+        )
+    Range constraints: {s0: VR[0, int_oo]}
 
 Some additional things to note:
 
 * Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we specified the first
-  dimension of each input to be dynamic. Looking at the inputs ``arg5_1`` and
-  ``arg6_1``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
+  dimension of each input to be dynamic. Looking at the inputs ``x1`` and
+  ``x2``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of
   the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs.
   ``s0`` is a symbol representing that this dimension can be a range
   of values.
 
 * ``exported_program.range_constraints`` describes the ranges of each symbol
   appearing in the graph. In this case, we see that ``s0`` has the range
-  [2, inf]. For technical reasons that are difficult to explain here, they are
+  [0, int_oo]. For technical reasons that are difficult to explain here, they are
   assumed to be not 0 or 1. This is not a bug, and does not necessarily mean
   that the exported program will not work for dimensions 0 or 1. See
   `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`_
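The module and export call behind this output would look roughly as follows (reconstructed from the printed shapes and ``# code:`` comments; note that the plain tensor attribute ``self.buffer`` surfaces as the ``CONSTANT_TENSOR`` input ``c_buffer`` in the signature above):

.. code-block:: python

    import torch
    from torch.export import Dim

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
            self.branch2 = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU())
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            out1 = self.branch1(x1)
            out2 = self.branch2(x2)
            return (out1 + self.buffer, out2)

    batch = Dim("batch")
    ep = torch.export.export(
        M(),
        (torch.randn(32, 64), torch.randn(32, 128)),
        dynamic_shapes={"x1": {0: batch}, "x2": {0: batch}},
    )
    print(ep.range_constraints)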
@@ -591,21 +814,37 @@ another, or a shape is even. An example:
     ExportedProgram:
         class GraphModule(torch.nn.Module):
-            def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
+            def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
                 # code: return x + y[1:]
-                slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
-                add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
+                slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
+                add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
                 return (add,)
 
-    Graph signature: ExportGraphSignature(
+    Graph signature:
+        ExportGraphSignature(
             input_specs=[
-                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
-                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='x'),
+                    target=None,
+                    persistent=None
+                ),
+                InputSpec(
+                    kind=<InputKind.USER_INPUT: 1>,
+                    arg=TensorArgument(name='y'),
+                    target=None,
+                    persistent=None
+                )
             ],
             output_specs=[
-                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
-    Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
+                OutputSpec(
+                    kind=<OutputKind.USER_OUTPUT: 1>,
+                    arg=TensorArgument(name='add'),
+                    target=None
+                )
+            ]
+        )
+    Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}
 
 Some things to note:
@@ -613,8 +852,8 @@ Some things to note:
   shape of the first input is now dynamic, being ``[s0]``. And now by specifying
   ``{0: dimy}`` for the second input, we see that the resulting shape of the
   second input is also dynamic. However, because we expressed ``dimy = dimx + 1``,
-  instead of ``arg1_1``'s shape containing a new symbol, we see that it is
-  now being represented with the same symbol used in ``arg0_1``, ``s0``. We can
+  instead of ``y``'s shape containing a new symbol, we see that it is
+  now being represented with the same symbol used in ``x``, ``s0``. We can
   see that relationship of ``dimy = dimx + 1`` is being shown through ``s0 + 1``.
 
 * Looking at the range constraints, we see that ``s0`` has the range [3, 6],
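The derived-dimension relationship described in these notes is set up as follows (a minimal sketch of the ``dimy = dimx + 1`` example; the sample input shapes must satisfy the relation):

.. code-block:: python

    import torch
    from torch.export import Dim

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y[1:]

    dimx = Dim("dimx", min=3, max=6)
    dimy = dimx + 1  # y's first dimension is always x's plus one

    ep = torch.export.export(
        M(),
        (torch.randn(4), torch.randn(5)),
        dynamic_shapes={"x": {0: dimx}, "y": {0: dimy}},
    )
    print(ep.range_constraints)  # s0 in [3, 6], s0 + 1 in [4, 7]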
@@ -701,8 +940,9 @@ that is being taken with the given sample inputs. For example:
     ExportedProgram:
         class GraphModule(torch.nn.Module):
-            def forward(self, arg0_1: f32[10, 2]):
-                add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
+            def forward(self, x: "f32[10, 2]"):
+                # code: return x + 1
+                add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
                 return (add,)
 
 The conditional of (``x.shape[0] > 5``) does not appear in the
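A module like the following, exported with a sample input of shape ``(10, 2)``, bakes in the ``x + 1`` branch exactly as shown above (a sketch; capturing both branches in the graph would require ``torch.cond`` instead of a Python ``if``):

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 5:  # evaluated once, against the sample input
                return x + 1
            return x - 1

    ep = torch.export.export(M(), (torch.randn(10, 2),))
    print(ep)  # only the `x + 1` branch appears in the graph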
@@ -745,19 +985,20 @@ For example:
     ExportedProgram:
         class GraphModule(torch.nn.Module):
-            def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
-                add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
-                add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
-                add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
+            def forward(self, x: "f32[2, 2]", const, times):
+                # code: x = x + const
+                add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
+                add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
+                add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
                 return (add_2,)
 
 Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations
-are all computed with the hard-coded constant ``1``, rather than ``arg1_1``. If
-a user passes a different value for ``arg1_1`` at runtime, like 2, than the one used
+are all computed with the hard-coded constant ``1``, rather than ``const``. If
+a user passes a different value for ``const`` at runtime, like 2, than the one used
 during export time, 1, this will result in an error.
 Additionally, the ``times`` iterator used in the ``for`` loop is also "inlined"
 in the graph through the 3 repeated ``torch.ops.aten.add.Tensor`` calls, and the
-input ``arg2_1`` is never used.
+input ``times`` is never used.
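A module consistent with this output, where ``const=1`` and ``times=3`` are burned into the three ``add`` calls, would be:

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x, const, times):
            for _ in range(times):
                x = x + const
            return x

    # The int arguments are specialized to the sample values 1 and 3.
    ep = torch.export.export(M(), (torch.randn(2, 2), 1, 3))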
 
 Python Containers
 ~~~~~~~~~~~~~~~~~
@@ -589,22 +589,27 @@ def register_dataclass(
     Example::
 
         import torch
         from dataclasses import dataclass
 
         @dataclass
         class InputDataClass:
             feature: torch.Tensor
             bias: int
 
         @dataclass
         class OutputDataClass:
             res: torch.Tensor
 
         torch.export.register_dataclass(InputDataClass)
         torch.export.register_dataclass(OutputDataClass)
 
-        def fn(o: InputDataClass) -> torch.Tensor:
-            res = res=o.feature + o.bias
+        class Mod(torch.nn.Module):
+            def forward(self, x: InputDataClass) -> OutputDataClass:
+                res = x.feature + x.bias
                 return OutputDataClass(res=res)
 
-        ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
+        ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
         print(ep)
 
     """
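As a follow-up to the fixed example, the registered dataclasses round-trip through the exported program; a short sketch continuing from the names defined in the docstring above:

.. code-block:: python

    # Inputs are flattened and outputs re-assembled via the registered
    # dataclass pytree nodes.
    out = ep.module()(InputDataClass(torch.ones(2, 2), 1))
    print(out.res)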