mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[1/N] [Caffe2] Remove caffe2_aten_fallback code (#128675)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128675 Approved by: https://github.com/r-barnes
This commit is contained in:
parent
8953725e6d
commit
163847b1bb
19 changed files with 33 additions and 529 deletions
|
|
@ -86,26 +86,6 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
|
|||
x = torch.ones(3)
|
||||
torch.onnx.export(foo, (x,), f)
|
||||
|
||||
@common_utils.skipIfNoCaffe2
|
||||
@common_utils.skipIfNoLapack
|
||||
def test_caffe2_aten_fallback(self):
|
||||
class ModelWithAtenNotONNXOp(nn.Module):
|
||||
def forward(self, x, y):
|
||||
abcd = x + y
|
||||
defg = torch.linalg.qr(abcd)
|
||||
return defg
|
||||
|
||||
x = torch.rand(3, 4)
|
||||
y = torch.rand(3, 4)
|
||||
torch.onnx.export_to_pretty_string(
|
||||
ModelWithAtenNotONNXOp(),
|
||||
(x, y),
|
||||
add_node_names=False,
|
||||
do_constant_folding=False,
|
||||
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
||||
)
|
||||
|
||||
@common_utils.skipIfCaffe2
|
||||
@common_utils.skipIfNoLapack
|
||||
def test_aten_fallback(self):
|
||||
class ModelWithAtenNotONNXOp(nn.Module):
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ from torch.onnx.symbolic_helper import (
|
|||
parse_args,
|
||||
)
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack
|
||||
from torch.testing._internal.common_utils import skipIfNoLapack
|
||||
|
||||
unittest.TestCase.maxDiff = None
|
||||
|
||||
|
|
@ -414,7 +414,6 @@ class TestOperators(common_utils.TestCase):
|
|||
x = torch.randn(20, 16, 50)
|
||||
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
|
||||
|
||||
@skipIfCaffe2
|
||||
def test_at_op(self):
|
||||
x = torch.randn(3, 4)
|
||||
|
||||
|
|
@ -694,7 +693,6 @@ class TestOperators(common_utils.TestCase):
|
|||
keep_initializers_as_inputs=True,
|
||||
)
|
||||
|
||||
@skipIfCaffe2
|
||||
def test_embedding_bags(self):
|
||||
emb_bag = nn.EmbeddingBag(10, 8)
|
||||
input = torch.tensor([1, 2, 3, 4]).long()
|
||||
|
|
@ -949,7 +947,6 @@ class TestOperators(common_utils.TestCase):
|
|||
other = torch.randint(-50, 50, (2, 3, 4), dtype=torch.int8)
|
||||
self.assertONNX(BiwiseAndModel(), (input, other), opset_version=18)
|
||||
|
||||
@skipIfCaffe2
|
||||
def test_layer_norm_aten(self):
|
||||
model = torch.nn.LayerNorm([10, 10])
|
||||
x = torch.randn(20, 5, 10, 10)
|
||||
|
|
@ -1203,7 +1200,6 @@ class TestOperators(common_utils.TestCase):
|
|||
torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
|
||||
|
||||
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
|
||||
@skipIfCaffe2
|
||||
def test_aten_embedding_2(self):
|
||||
_onnx_opset_version = 12
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import pytorch_test_common
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.onnx import OperatorExportTypes, symbolic_helper, utils
|
||||
from torch.onnx import symbolic_helper, utils
|
||||
from torch.onnx._internal import registration
|
||||
from torch.testing._internal import common_quantization, common_utils, jit_utils
|
||||
|
||||
|
|
@ -394,7 +394,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
|||
for node in graph.nodes():
|
||||
self.assertTrue(node.sourceRange())
|
||||
|
||||
@common_utils.skipIfCaffe2
|
||||
def test_clip_aten_fallback_due_exception(self):
|
||||
def bad_clamp(g, self, min, max):
|
||||
return symbolic_helper._onnx_unsupported("Bad boy!")
|
||||
|
|
@ -411,7 +410,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
|||
)
|
||||
self.assertAtenOp(onnx_model, "clamp", "Tensor")
|
||||
|
||||
@common_utils.skipIfCaffe2
|
||||
def test_clip_aten_fallback_explicit_request(self):
|
||||
class MyClip(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
@ -961,60 +959,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
|||
|
||||
torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
|
||||
|
||||
@common_utils.skipIfNoCaffe2
|
||||
def test_caffe2_aten_fallback_must_fallback(self):
|
||||
class ModelWithAtenNotONNXOp(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
abcd = x + y
|
||||
defg = torch.linalg.qr(abcd)
|
||||
return defg
|
||||
|
||||
# TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
|
||||
for operator_export_type in (
|
||||
OperatorExportTypes.ONNX_ATEN,
|
||||
OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
||||
):
|
||||
x = torch.rand(3, 4)
|
||||
y = torch.rand(3, 4)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
ModelWithAtenNotONNXOp(),
|
||||
(x, y),
|
||||
f,
|
||||
do_constant_folding=False,
|
||||
operator_export_type=operator_export_type,
|
||||
# support for linalg.qr was added in later op set versions.
|
||||
opset_version=9,
|
||||
)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertAtenOp(onnx_model, "linalg_qr")
|
||||
|
||||
@common_utils.skipIfNoCaffe2
|
||||
def test_caffe2_onnx_aten_must_not_fallback(self):
|
||||
class ModelWithAtenFmod(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return torch.fmod(x, y)
|
||||
|
||||
# TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
|
||||
for operator_export_type in (
|
||||
OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
||||
OperatorExportTypes.ONNX_ATEN,
|
||||
):
|
||||
x = torch.randn(3, 4, dtype=torch.float32)
|
||||
y = torch.randn(3, 4, dtype=torch.float32)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(
|
||||
ModelWithAtenFmod(),
|
||||
(x, y),
|
||||
f,
|
||||
do_constant_folding=False,
|
||||
operator_export_type=operator_export_type,
|
||||
opset_version=10, # or higher
|
||||
)
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
assert onnx_model.graph.node[0].op_type == "Mod"
|
||||
|
||||
@common_utils.skipIfCaffe2
|
||||
def test_aten_fallback_must_fallback(self):
|
||||
class ModelWithAtenNotONNXOp(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
@ -1037,7 +981,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
|||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertAtenOp(onnx_model, "linalg_qr")
|
||||
|
||||
@common_utils.skipIfCaffe2
|
||||
def test_onnx_aten(self):
|
||||
class ModelWithAtenFmod(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
@ -1056,7 +999,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
|||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertAtenOp(onnx_model, "fmod", "Tensor")
|
||||
|
||||
@common_utils.skipIfCaffe2
|
||||
def test_onnx_aten_fallback_must_not_fallback(self):
|
||||
# For BUILD_CAFFE2=0, aten fallback only when not exportable
|
||||
class ONNXExportable(torch.nn.Module):
|
||||
|
|
@ -1233,26 +1175,6 @@ class TestQuantizeEagerONNXExport(common_utils.TestCase):
|
|||
|
||||
_export_to_onnx(model, data, input_names)
|
||||
|
||||
@common_quantization.skipIfNoFBGEMM
|
||||
@common_utils.skipIfNoCaffe2
|
||||
def test_lower_graph_linear(self):
|
||||
model = torch.ao.quantization.QuantWrapper(
|
||||
torch.nn.Linear(5, 10, bias=True)
|
||||
).to(dtype=torch.float)
|
||||
data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
|
||||
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
|
||||
self._test_lower_graph_impl(model, data)
|
||||
|
||||
@common_quantization.skipIfNoFBGEMM
|
||||
@common_utils.skipIfNoCaffe2
|
||||
def test_lower_graph_conv2d(self):
|
||||
model = torch.ao.quantization.QuantWrapper(
|
||||
torch.nn.Conv2d(3, 5, 2, bias=True)
|
||||
).to(dtype=torch.float)
|
||||
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
|
||||
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
|
||||
self._test_lower_graph_impl(model, data)
|
||||
|
||||
@common_quantization.skipIfNoFBGEMM
|
||||
@unittest.skip(
|
||||
"onnx opset9 does not support quantize_per_tensor and caffe2 \
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from pytorch_test_common import (
|
|||
skipIfUnsupportedMaxOpsetVersion,
|
||||
skipIfUnsupportedMinOpsetVersion,
|
||||
)
|
||||
from verify import verify
|
||||
|
||||
import torch
|
||||
import torch.onnx
|
||||
|
|
@ -26,7 +25,7 @@ from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
|
|||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx.symbolic_helper import _unpack_list, parse_args
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
|
||||
from torch.testing._internal.common_utils import skipIfNoLapack
|
||||
|
||||
|
||||
def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
|
||||
|
|
@ -1623,25 +1622,6 @@ class TestUtilityFuns(_BaseTestCase):
|
|||
"Graph parameter names does not match model parameters.",
|
||||
)
|
||||
|
||||
@skipIfNoCaffe2
|
||||
def test_modifying_params(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param = torch.nn.Parameter(torch.tensor([2.0]))
|
||||
|
||||
def forward(self, x):
|
||||
y = x * x
|
||||
self.param.data.add_(1.0)
|
||||
return y
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
# Move import to local as caffe2 backend requires additional build flag,
|
||||
# and is only used in this test case.
|
||||
import caffe2.python.onnx.backend as backend
|
||||
|
||||
verify(MyModel(), x, backend, do_constant_folding=False)
|
||||
|
||||
def test_fuse_conv_bn(self):
|
||||
class Fuse(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ hu.assert_deadline_disabled()
|
|||
|
||||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE
|
||||
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
|
||||
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
|
||||
override_quantized_engine, supported_qengines, override_qengines, _snr
|
||||
|
|
@ -4524,47 +4524,6 @@ class TestQuantizedEmbeddingOps(TestCase):
|
|||
self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight)
|
||||
|
||||
|
||||
""" Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
|
||||
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
|
||||
@given(num_embeddings=st.integers(10, 100),
|
||||
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
|
||||
num_batches=st.integers(1, 5),
|
||||
data_type=st.sampled_from([np.float32, np.float16]),)
|
||||
def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches, data_type):
|
||||
pack_fn = torch.ops.quantized.embedding_bag_byte_prepack
|
||||
unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack
|
||||
|
||||
self._test_embedding_bag_unpack_fn(
|
||||
pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches, data_type=data_type)
|
||||
|
||||
""" Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """
|
||||
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
|
||||
@given(num_embeddings=st.integers(10, 100),
|
||||
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
|
||||
optimized_qparams=st.booleans(),
|
||||
data_type=st.sampled_from([np.float32, np.float16]),)
|
||||
def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
|
||||
pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
|
||||
unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack
|
||||
|
||||
# 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1
|
||||
self._test_embedding_bag_unpack_fn(
|
||||
pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type)
|
||||
|
||||
""" Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """
|
||||
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
|
||||
@given(num_embeddings=st.integers(10, 100),
|
||||
embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
|
||||
optimized_qparams=st.booleans(),
|
||||
data_type=st.sampled_from([np.float32, np.float16]),)
|
||||
def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
|
||||
pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
|
||||
unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
|
||||
|
||||
# 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1
|
||||
self._test_embedding_bag_unpack_fn(
|
||||
pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type)
|
||||
|
||||
|
||||
def embedding_bag_rowwise_offsets_run(
|
||||
self, bit_rate, num_embeddings,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ import torch.nn.functional as F
|
|||
from torch.testing._internal import jit_utils
|
||||
from torch.testing._internal.common_jit import check_against_reference
|
||||
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
||||
suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
|
||||
suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
|
||||
freeze_rng_state, slowTest, TemporaryFileName, \
|
||||
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
|
||||
skipIfCrossRef, skipIfTorchDynamo
|
||||
|
|
@ -15299,56 +15299,6 @@ dedent """
|
|||
continue
|
||||
self.assertEqual(value, getattr(loaded, "_" + name))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
|
||||
@unittest.skipIf(not BUILD_WITH_CAFFE2, "PyTorch is build without Caffe2 support")
|
||||
def test_old_models_bc(self):
|
||||
model = {
|
||||
'archive/version': b'1',
|
||||
'archive/code/archive.py':
|
||||
b'''
|
||||
op_version_set = 0
|
||||
def forward(self,
|
||||
_0: Tensor) -> Tensor:
|
||||
_1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"))
|
||||
result = torch.to(torch.fill_(_1, 5), dtype=6, layout=0, device=torch.device("cpu"),
|
||||
non_blocking=False, copy=False)
|
||||
result2 = torch.rand([10], dtype=6, layout=0, device=torch.device("cpu"))
|
||||
result3 = torch.rand_like(result2, dtype=6, layout=0, device=torch.device("cpu"))
|
||||
_2 = torch.add(torch.add(result, result2, alpha=1), result3, alpha=1)
|
||||
return _2
|
||||
''',
|
||||
'archive/attributes.pkl': b'\x80\x02](e.',
|
||||
'archive/libs.py': b'op_version_set = 0\n',
|
||||
'archive/model.json':
|
||||
b'''
|
||||
{
|
||||
"protoVersion":"2",
|
||||
"mainModule":{
|
||||
"torchscriptArena":{
|
||||
"key":"code/archive.py"
|
||||
},
|
||||
"name":"archive",
|
||||
"optimize":true
|
||||
},
|
||||
"producerName":"pytorch",
|
||||
"producerVersion":"1.0",
|
||||
"libs":{
|
||||
"torchscriptArena":{
|
||||
"key":"libs.py"
|
||||
}
|
||||
}
|
||||
}'''}
|
||||
with TemporaryFileName() as fname:
|
||||
archive_name = os.path.basename(os.path.normpath(fname))
|
||||
with zipfile.ZipFile(fname, 'w') as archive:
|
||||
for k, v in model.items():
|
||||
archive.writestr(k, v)
|
||||
|
||||
with open(fname, "rb") as f:
|
||||
fn = torch.jit.load(f)
|
||||
|
||||
x = torch.zeros(10)
|
||||
fn(x)
|
||||
|
||||
def test_submodule_attribute_serialization(self):
|
||||
class S(torch.jit.ScriptModule):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
_CAFFE2_ATEN_FALLBACK: bool
|
||||
PRODUCER_VERSION: str
|
||||
|
||||
class TensorProtoDataType(Enum):
|
||||
|
|
|
|||
|
|
@ -292,7 +292,5 @@ void initONNXBindings(PyObject* module) {
|
|||
.value("TRAINING", TrainingMode::TRAINING);
|
||||
|
||||
onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);
|
||||
|
||||
onnx.attr("_CAFFE2_ATEN_FALLBACK") = false;
|
||||
}
|
||||
} // namespace torch::onnx
|
||||
|
|
|
|||
|
|
@ -1,12 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from torch import _C
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch._C._onnx import (
|
||||
_CAFFE2_ATEN_FALLBACK,
|
||||
OperatorExportTypes,
|
||||
TensorProtoDataType,
|
||||
TrainingMode,
|
||||
)
|
||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||
|
||||
from . import ( # usort:skip. Keep the order instead of sorting lexicographically
|
||||
_deprecation,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Un
|
|||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch._C import _onnx as _C_onnx
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype, registration
|
||||
|
||||
|
|
@ -329,14 +328,6 @@ def _scalar(x: torch.Tensor):
|
|||
return x[0]
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _is_caffe2_aten_fallback() -> bool:
|
||||
return (
|
||||
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
and _C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
|
||||
r"""Initializes the right attribute based on type of value."""
|
||||
|
|
@ -350,16 +341,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
|
|||
if _is_onnx_list(value):
|
||||
kind += "s"
|
||||
|
||||
if aten and _is_caffe2_aten_fallback():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Caffe2 proto does not support tensor attribute.
|
||||
if value.numel() > 1:
|
||||
raise ValueError("Should not pass tensor attribute")
|
||||
value = _scalar(value)
|
||||
if isinstance(value, float):
|
||||
kind = "f"
|
||||
else:
|
||||
kind = "i"
|
||||
return getattr(node, f"{kind}_")(name, value)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -537,10 +537,7 @@ def is_complex_value(x: _C.Value) -> bool:
|
|||
|
||||
@_beartype.beartype
|
||||
def is_caffe2_aten_fallback() -> bool:
|
||||
return (
|
||||
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
and _C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
|
|
@ -592,9 +589,7 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
|
|||
@_beartype.beartype
|
||||
def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
|
||||
# For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
|
||||
if _C_onnx._CAFFE2_ATEN_FALLBACK:
|
||||
warnings.warn(f"ONNX export failed on {op} because {msg} not supported")
|
||||
elif GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
|
||||
if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
|
||||
_onnx_unsupported(f"{op}, {msg}", value)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -211,10 +211,6 @@ def index_put(
|
|||
indices_list = symbolic_helper._unpack_list(indices_list_value)
|
||||
else:
|
||||
indices_list = [indices_list_value]
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
args = [self] + indices_list + [values, accumulate]
|
||||
return g.at("index_put", *args)
|
||||
|
||||
accumulate = symbolic_helper._parse_arg(accumulate, "b")
|
||||
|
||||
if len(indices_list) == 0:
|
||||
|
|
@ -398,8 +394,6 @@ def __interpolate(
|
|||
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
|
||||
if symbolic_helper._maybe_get_const(sparse_grad, "i"):
|
||||
return symbolic_helper._unimplemented("gather", "sparse_grad == True")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("gather", self, dim, index, sparse_grad)
|
||||
return g.op("GatherElements", self, index, axis_i=dim)
|
||||
|
||||
|
||||
|
|
@ -407,8 +401,6 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
|
|||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("scatter", self, dim, index, src, overload_name="src")
|
||||
src_type = _type_utils.JitScalarType.from_value(src)
|
||||
src = symbolic_helper._maybe_get_scalar(src)
|
||||
if symbolic_helper._is_value(src):
|
||||
|
|
@ -898,8 +890,6 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim):
|
|||
stop = g.op(
|
||||
"Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
|
||||
)
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.op("_caffe2::Range", stop)
|
||||
return arange(g, stop, 4, None, None, None)
|
||||
|
||||
|
||||
|
|
@ -982,9 +972,6 @@ def mm(g: jit_utils.GraphContext, self, other):
|
|||
@_onnx_symbolic("aten::index")
|
||||
@_beartype.beartype
|
||||
def index(g: jit_utils.GraphContext, self, index):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("index", self, index, overload_name="Tensor")
|
||||
|
||||
if symbolic_helper._is_packed_list(index):
|
||||
indices = symbolic_helper._unpack_list(index)
|
||||
else:
|
||||
|
|
@ -1007,16 +994,6 @@ def index(g: jit_utils.GraphContext, self, index):
|
|||
@_beartype.beartype
|
||||
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"index_fill",
|
||||
self,
|
||||
index,
|
||||
value,
|
||||
overload_name="int_Scalar",
|
||||
dim_i=dim_value,
|
||||
)
|
||||
|
||||
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
|
||||
g, self, dim, index
|
||||
)
|
||||
|
|
@ -1030,8 +1007,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
|
|||
@_beartype.beartype
|
||||
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("index_copy", self, index, source, dim_i=dim_value)
|
||||
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
|
||||
g, self, dim, index
|
||||
)
|
||||
|
|
|
|||
|
|
@ -330,8 +330,6 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
|||
const_step
|
||||
):
|
||||
return opset9.unfold(g, input, dimension, const_size, const_step)
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
|
||||
|
||||
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
|
||||
if sizedim is not None:
|
||||
|
|
|
|||
|
|
@ -71,9 +71,6 @@ def grid_sampler(
|
|||
@symbolic_helper.parse_args("v", "i", "v", "v")
|
||||
@_beartype.beartype
|
||||
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("scatter", self, dim, index, src, overload_name="src")
|
||||
|
||||
src_type = _type_utils.JitScalarType.from_value(
|
||||
src, _type_utils.JitScalarType.UNDEFINED
|
||||
)
|
||||
|
|
|
|||
|
|
@ -841,36 +841,18 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool =
|
|||
@symbolic_helper.parse_args("v", "i", "none")
|
||||
@_beartype.beartype
|
||||
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
if dtype.node().kind() != "prim::Constant":
|
||||
return symbolic_helper._unimplemented("cumsum", "dtype", dtype)
|
||||
return g.at("cumsum", input, dim_i=dim)
|
||||
|
||||
symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_sample_dirichlet")
|
||||
@_beartype.beartype
|
||||
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
if not symbolic_helper._is_none(generator):
|
||||
return symbolic_helper._unimplemented(
|
||||
"_sample_dirichlet", "We are not able to export generator", self
|
||||
)
|
||||
return g.at("_sample_dirichlet", self)
|
||||
return symbolic_helper._onnx_unsupported("_sample_dirichlet", self)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_standard_gamma")
|
||||
@_beartype.beartype
|
||||
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
if not symbolic_helper._is_none(generator):
|
||||
return symbolic_helper._unimplemented(
|
||||
"_standard_gamma", "not able to export generator", self
|
||||
)
|
||||
return g.at("_standard_gamma", self)
|
||||
|
||||
return symbolic_helper._onnx_unsupported("_standard_gamma", self)
|
||||
|
||||
|
||||
|
|
@ -1007,19 +989,6 @@ def embedding_bag(
|
|||
return symbolic_helper._onnx_unsupported(
|
||||
"embedding_bag with per_sample_weights"
|
||||
)
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"embedding_bag",
|
||||
embedding_matrix,
|
||||
indices,
|
||||
offsets,
|
||||
outputs=4,
|
||||
scale_grad_by_freq_i=scale_grad_by_freq,
|
||||
mode_i=mode,
|
||||
sparse_i=sparse,
|
||||
include_last_offset_i=include_last_offset,
|
||||
padding_idx_i=padding_idx,
|
||||
)
|
||||
|
||||
return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)
|
||||
|
||||
|
|
@ -1052,10 +1021,6 @@ def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
|
|||
axes = list(range(rank))
|
||||
axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
|
||||
return g.op("Transpose", self, perm_i=axes)
|
||||
elif symbolic_helper.is_caffe2_aten_fallback():
|
||||
# if we don't have dim information we cannot
|
||||
# output a permute so use ATen instead
|
||||
return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1)
|
||||
else:
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of transpose for tensor of unknown rank.",
|
||||
|
|
@ -2927,16 +2892,6 @@ def layer_norm(
|
|||
eps: float,
|
||||
cudnn_enable: bool,
|
||||
) -> _C.Value:
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"layer_norm",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
normalized_shape_i=normalized_shape,
|
||||
eps_f=eps,
|
||||
cudnn_enable_i=cudnn_enable,
|
||||
)
|
||||
normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps)
|
||||
return normalized
|
||||
|
||||
|
|
@ -3043,8 +2998,6 @@ def instance_norm(
|
|||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
|
||||
sizes = symbolic_helper._get_tensor_sizes(input)
|
||||
# FIXME(justinchuby): Get rid of the try catch here to improve readability
|
||||
try:
|
||||
|
|
@ -3119,9 +3072,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum
|
|||
indices_list = symbolic_helper._unpack_list(indices_list_value)
|
||||
else:
|
||||
indices_list = [indices_list_value]
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
args = [self] + indices_list + [values, accumulate]
|
||||
return g.at("index_put", *args)
|
||||
|
||||
accumulate = symbolic_helper._parse_arg(accumulate, "b")
|
||||
|
||||
|
|
@ -3136,16 +3086,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum
|
|||
@_beartype.beartype
|
||||
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"index_fill",
|
||||
self,
|
||||
index,
|
||||
value,
|
||||
overload_name="int_Scalar",
|
||||
dim_i=dim_value,
|
||||
)
|
||||
|
||||
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
|
||||
g, self, dim, index
|
||||
)
|
||||
|
|
@ -3160,8 +3100,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
|
|||
@_beartype.beartype
|
||||
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
|
||||
dim_value = symbolic_helper._parse_arg(dim, "i")
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("index_copy", self, index, source, dim_i=dim_value)
|
||||
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
|
||||
g, self, dim, index
|
||||
)
|
||||
|
|
@ -3220,10 +3158,6 @@ def type_as(g: jit_utils.GraphContext, self, other):
|
|||
to_i=other_dtype.onnx_type(),
|
||||
)
|
||||
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
# We don't know the type of other, bail by emitting ATen
|
||||
return g.at("type_as", self, other)
|
||||
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of type_as for tensor "
|
||||
"of unknown dtype. Please check if the dtype of the "
|
||||
|
|
@ -3236,8 +3170,6 @@ def type_as(g: jit_utils.GraphContext, self, other):
|
|||
@symbolic_helper.parse_args("v", "v", "i", "f")
|
||||
@_beartype.beartype
|
||||
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
|
||||
cross = symbolic_helper._reducesum_helper(
|
||||
g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0
|
||||
)
|
||||
|
|
@ -3516,50 +3448,28 @@ def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
|
|||
@symbolic_helper.parse_args("v", "v", "v", "i")
|
||||
@_beartype.beartype
|
||||
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("conv_tbc", input, weight, bias, pad_i=pad)
|
||||
else:
|
||||
# input must have 3 dimensions, see:
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
|
||||
# input = (time, batch, in_channels)
|
||||
# weight = (kernel_width, in_channels, out_channels)
|
||||
# bias = (out_channels,)
|
||||
input = g.op("Transpose", input, perm_i=[1, 2, 0])
|
||||
weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
|
||||
conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
|
||||
return g.op("Transpose", conv, perm_i=[2, 0, 1])
|
||||
# input must have 3 dimensions, see:
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
|
||||
# input = (time, batch, in_channels)
|
||||
# weight = (kernel_width, in_channels, out_channels)
|
||||
# bias = (out_channels,)
|
||||
input = g.op("Transpose", input, perm_i=[1, 2, 0])
|
||||
weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
|
||||
conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
|
||||
return g.op("Transpose", conv, perm_i=[2, 0, 1])
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_unique")
|
||||
@symbolic_helper.parse_args("v", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"_unique",
|
||||
input,
|
||||
sorted_i=sorted,
|
||||
return_inverse_i=return_inverse,
|
||||
outputs=2,
|
||||
)
|
||||
else:
|
||||
return symbolic_helper._onnx_unsupported("_unique", input)
|
||||
return symbolic_helper._onnx_unsupported("_unique", input)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::_unique2")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i")
|
||||
@_beartype.beartype
|
||||
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"_unique2",
|
||||
input,
|
||||
sorted_i=sorted,
|
||||
return_inverse_i=return_inverse,
|
||||
return_counts_i=return_counts,
|
||||
outputs=3,
|
||||
)
|
||||
|
||||
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
|
||||
|
||||
|
||||
|
|
@ -4973,11 +4883,8 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim):
|
|||
stop = g.op(
|
||||
"Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
|
||||
)
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.op("_caffe2::Range", stop)
|
||||
else:
|
||||
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
return arange(g, stop, 4, None, None, None)
|
||||
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
return arange(g, stop, 4, None, None, None)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::detach")
|
||||
|
|
@ -5543,9 +5450,6 @@ def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
|
|||
@_onnx_symbolic("aten::arange")
|
||||
@_beartype.beartype
|
||||
def arange(g: jit_utils.GraphContext, *args):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("arange", *args)
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_arange_dtype(dtype):
|
||||
dtype = symbolic_helper._maybe_get_const(dtype, "i")
|
||||
|
|
@ -5665,9 +5569,6 @@ def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
|
|||
@_onnx_symbolic("aten::index")
|
||||
@_beartype.beartype
|
||||
def index(g: jit_utils.GraphContext, self, index):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("index", self, index, overload_name="Tensor")
|
||||
|
||||
if symbolic_helper._is_packed_list(index):
|
||||
indices = symbolic_helper._unpack_list(index)
|
||||
else:
|
||||
|
|
@ -6083,17 +5984,6 @@ def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "no
|
|||
def group_norm(
|
||||
g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
|
||||
):
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at(
|
||||
"group_norm",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
num_groups_i=num_groups,
|
||||
eps_f=eps,
|
||||
cudnn_enabled_i=cudnn_enabled,
|
||||
)
|
||||
|
||||
channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
|
||||
if channel_size is not None:
|
||||
assert channel_size % num_groups == 0
|
||||
|
|
@ -6169,9 +6059,6 @@ def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
|
|||
norm_v = norm(g, weight_v, 2, axes, 1)
|
||||
div = g.op("Div", weight_v, norm_v)
|
||||
return g.op("Mul", div, weight_g)
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
return g.at("_weight_norm", weight_v, weight_g, dim_i=dim)
|
||||
|
||||
raise errors.SymbolicValueError(
|
||||
"Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
|
||||
weight_v,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import copy
|
|||
import inspect
|
||||
import io
|
||||
import re
|
||||
import textwrap
|
||||
import typing
|
||||
import warnings
|
||||
from typing import (
|
||||
|
|
@ -681,27 +680,6 @@ def _optimize_graph(
|
|||
_C._jit_pass_onnx_unpack_quantized_weights(
|
||||
graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
|
||||
)
|
||||
if symbolic_helper.is_caffe2_aten_fallback():
|
||||
# Insert permutes before and after each conv op to ensure correct order.
|
||||
_C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
|
||||
|
||||
# Find consecutive permutes that are no-ops and remove them.
|
||||
_C._jit_pass_custom_pattern_based_rewrite_graph(
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
graph(%Pi):
|
||||
%Pq = quantized::nhwc2nchw(%Pi)
|
||||
%Pr = quantized::nchw2nhwc(%Pq)
|
||||
return (%Pr)"""
|
||||
),
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
graph(%Ri):
|
||||
return (%Ri)"""
|
||||
),
|
||||
graph,
|
||||
)
|
||||
|
||||
# onnx only supports tensors, so we turn all out number types into tensors
|
||||
_C._jit_pass_erase_number_types(graph)
|
||||
if GLOBALS.onnx_shape_inference:
|
||||
|
|
@ -734,18 +712,9 @@ def _optimize_graph(
|
|||
graph = _C._jit_pass_canonicalize(graph)
|
||||
_C._jit_pass_lint(graph)
|
||||
if GLOBALS.onnx_shape_inference:
|
||||
try:
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if (
|
||||
_C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
and exc.args[0]
|
||||
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
|
||||
):
|
||||
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
|
||||
pass
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
|
@ -783,17 +752,6 @@ def warn_on_static_input_change(input_states):
|
|||
@_beartype.beartype
|
||||
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
|
||||
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
|
||||
if (
|
||||
operator_export_type is not operator_export_type.ONNX
|
||||
and _C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
):
|
||||
if arg_value is True:
|
||||
warnings.warn(
|
||||
f"'{arg_name}' can be set to True only when 'operator_export_type' is "
|
||||
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
|
||||
f"'{arg_name}' argument will be ignored."
|
||||
)
|
||||
arg_value = False
|
||||
return arg_value
|
||||
|
||||
|
||||
|
|
@ -1298,18 +1256,9 @@ def _model_to_graph(
|
|||
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||
|
||||
if GLOBALS.onnx_shape_inference:
|
||||
try:
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if (
|
||||
_C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
and exc.args[0]
|
||||
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
|
||||
):
|
||||
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
|
||||
pass
|
||||
_C._jit_pass_onnx_graph_shape_type_inference(
|
||||
graph, params_dict, GLOBALS.export_onnx_opset_version
|
||||
)
|
||||
|
||||
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
|
||||
|
||||
|
|
@ -1612,15 +1561,6 @@ def _export(
|
|||
if export_type is None:
|
||||
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
|
||||
|
||||
# Discussed deprecation with Nikita Shulga and Sergii Dymchenko from Meta
|
||||
if _C_onnx._CAFFE2_ATEN_FALLBACK:
|
||||
warnings.warn(
|
||||
"Caffe2 ONNX exporter is deprecated in version 2.0 and will be "
|
||||
"removed in 2.2. Please use PyTorch 2.1 or older for this capability.",
|
||||
category=FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
raise ValueError(
|
||||
"torch.nn.DataParallel is not supported by ONNX "
|
||||
|
|
@ -1655,10 +1595,7 @@ def _export(
|
|||
"no local function support. "
|
||||
)
|
||||
if not operator_export_type:
|
||||
if _C_onnx._CAFFE2_ATEN_FALLBACK:
|
||||
operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
else:
|
||||
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
|
||||
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
|
||||
|
||||
# By default, training=TrainingMode.EVAL,
|
||||
# which is good because running a model in training mode could result in
|
||||
|
|
@ -1904,21 +1841,12 @@ def _should_aten_fallback(
|
|||
is_aten_fallback_export = (
|
||||
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
)
|
||||
is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK
|
||||
|
||||
if not name.startswith("aten::"):
|
||||
return False
|
||||
|
||||
if is_caffe2_build:
|
||||
if (
|
||||
is_onnx_aten_export or is_aten_fallback_export
|
||||
) and not is_exportable_aten_op:
|
||||
return True
|
||||
else:
|
||||
if is_onnx_aten_export or (
|
||||
is_aten_fallback_export and not is_exportable_aten_op
|
||||
):
|
||||
return True
|
||||
if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
|
@ -1968,7 +1896,7 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
|
|||
def _get_aten_op_overload_name(n: _C.Node) -> str:
|
||||
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
|
||||
schema = n.schema()
|
||||
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
|
||||
if not schema.startswith("aten::"):
|
||||
return ""
|
||||
return _C.parse_schema(schema).overload_name
|
||||
|
||||
|
|
@ -2032,14 +1960,7 @@ def _run_symbolic_function(
|
|||
)
|
||||
|
||||
try:
|
||||
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
|
||||
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
|
||||
symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
|
||||
|
||||
if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
|
||||
domain = "caffe2"
|
||||
else:
|
||||
domain = namespace
|
||||
domain = namespace
|
||||
symbolic_function_name = f"{domain}::{op_name}"
|
||||
|
||||
symbolic_function_group = registration.registry.get_function_group(
|
||||
|
|
@ -2073,10 +1994,7 @@ def _run_symbolic_function(
|
|||
except RuntimeError:
|
||||
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
|
||||
return None
|
||||
elif (
|
||||
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
and not symbolic_helper.is_caffe2_aten_fallback()
|
||||
):
|
||||
elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
|
||||
# Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
|
||||
attrs = {
|
||||
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
|
||||
|
|
|
|||
|
|
@ -633,10 +633,7 @@ def _onnx_graph_from_model(
|
|||
utils._setup_trace_module_map(model, export_modules_as_functions)
|
||||
|
||||
if not operator_export_type:
|
||||
if _C_onnx._CAFFE2_ATEN_FALLBACK:
|
||||
operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
else:
|
||||
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
|
||||
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
|
||||
|
||||
GLOBALS.export_onnx_opset_version = opset_version
|
||||
GLOBALS.operator_export_type = operator_export_type
|
||||
|
|
|
|||
|
|
@ -329,14 +329,6 @@ def skipIfNoQNNPACK(fn):
|
|||
fn(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not torch.onnx._CAFFE2_ATEN_FALLBACK:
|
||||
raise unittest.SkipTest(reason)
|
||||
else:
|
||||
fn(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def withQNNPACKBackend(fn):
|
||||
# TODO(future PR): consider combining with skipIfNoQNNPACK,
|
||||
# will require testing of existing callsites
|
||||
|
|
|
|||
|
|
@ -1252,8 +1252,6 @@ TEST_OPT_EINSUM = _check_module_exists('opt_einsum')
|
|||
|
||||
TEST_Z3 = _check_module_exists('z3')
|
||||
|
||||
BUILD_WITH_CAFFE2 = torch.onnx._CAFFE2_ATEN_FALLBACK
|
||||
|
||||
def split_if_not_empty(x: str):
|
||||
return x.split(",") if len(x) != 0 else []
|
||||
|
||||
|
|
@ -1886,19 +1884,6 @@ def skipIfNotRegistered(op_name, message):
|
|||
"""
|
||||
return unittest.skip("Pytorch is compiled without Caffe2")
|
||||
|
||||
def _decide_skip_caffe2(expect_caffe2, reason):
|
||||
def skip_dec(func):
|
||||
@wraps(func)
|
||||
def wrapper(self):
|
||||
if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2:
|
||||
raise unittest.SkipTest(reason)
|
||||
return func(self)
|
||||
return wrapper
|
||||
return skip_dec
|
||||
|
||||
skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2")
|
||||
skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available")
|
||||
|
||||
def skipIfNoSciPy(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
|
|
|||
Loading…
Reference in a new issue