mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[export] Fix graph_break log registration error when importing export/_trace.py (#131523)
Summary:
When importing `_trace.py`, put `torch._dynamo.exc.Unsupported` in the global variable ``_ALLOW_LIST`` can cause import to ``export/_trace.py`` to fail with error:
ValueError: Artifact name: 'graph_breaks' not registered, please call register_artifact('graph_breaks') in torch._logging.registrations.
The error is directly raise on line `graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")` in `_dynamo/exc.py`. I've checked that ``register_artifact('graph_breaks')`` does already exist in torch._logging.registrations.
Explicitly call `import torch._logging` doesn't fix the issue.
(see T196719676)
We move ``_ALLOW_LIST`` to be a local variable.
Test Plan:
buck2 test 'fbcode//mode/opt' fbcode//aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test -- --exact 'aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test - test_serialized_model_for_disagg_acc (aiplatform.modelstore.publish.utils.tests.fc_transform_utils_test.PrepareSerializedModelTest)'
buck2 test 'fbcode//mode/opt' fbcode//aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test -- --exact 'aiplatform/modelstore/publish/utils/tests:fc_transform_utils_test - test_serialized_test_dsnn_module (aiplatform.modelstore.publish.utils.tests.fc_transform_utils_test.PrepareSerializedModelTest)'
Differential Revision: D60136706
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131523
Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
236e06f9f9
commit
29c9f8c782
1 changed files with 7 additions and 7 deletions
|
|
@ -991,14 +991,14 @@ _EXPORT_FLAGS: Optional[Set[str]] = None
|
|||
_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
_ALLOW_LIST = {
|
||||
torch._dynamo.exc.Unsupported,
|
||||
torch._dynamo.exc.UserError,
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
}
|
||||
|
||||
|
||||
def _get_class_if_classified_error(e):
|
||||
from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError
|
||||
|
||||
_ALLOW_LIST = {
|
||||
Unsupported,
|
||||
UserError,
|
||||
TorchRuntimeError,
|
||||
}
|
||||
case_name = getattr(e, "case_name", None)
|
||||
if type(e) in _ALLOW_LIST and case_name is not None:
|
||||
return case_name
|
||||
|
|
|
|||
Loading…
Reference in a new issue