[export] Fix graph_break log registration error when importing export/_trace.py (#131523)

Summary:
When importing `_trace.py`, put `torch._dynamo.exc.Unsupported` in the global variable ``_ALLOW_LIST`` can cause import to ``export/_trace.py`` to fail with error:

ValueError: Artifact name: 'graph_breaks' not registered, please call register_artifact('graph_breaks') in torch._logging.registrations.

The error is directly raise on line `graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")` in `_dynamo/exc.py`. I've checked that ``register_artifact('graph_breaks')`` does already exist in torch._logging.registrations.

Explicitly call `import torch._logging` doesn't fix the issue.

(see T196719676)

We move ``_ALLOW_LIST`` to be a local variable.

Test Plan:
buck2 test 'fbcode//mode/opt' fbcode//aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test -- --exact 'aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test - test_serialized_model_for_disagg_acc (aiplatform.modelstore.publish.utils.tests.fc_transform_utils_test.PrepareSerializedModelTest)'

buck2 test 'fbcode//mode/opt' fbcode//aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test -- --exact 'aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test - test_serialized_test_dsnn_module (aiplatform.modelstore.publish.utils.tests.fc_transform_utils_test.PrepareSerializedModelTest)'

Differential Revision: D60136706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131523
Approved by: https://github.com/zhxchen17
This commit is contained in:
Shangdi Yu 2024-07-24 22:40:24 +00:00 committed by PyTorch MergeBot
parent 236e06f9f9
commit 29c9f8c782

View file

@ -991,14 +991,14 @@ _EXPORT_FLAGS: Optional[Set[str]] = None
_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None
_ALLOW_LIST = {
torch._dynamo.exc.Unsupported,
torch._dynamo.exc.UserError,
torch._dynamo.exc.TorchRuntimeError,
}
def _get_class_if_classified_error(e):
from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError
_ALLOW_LIST = {
Unsupported,
UserError,
TorchRuntimeError,
}
case_name = getattr(e, "case_name", None)
if type(e) in _ALLOW_LIST and case_name is not None:
return case_name