diff --git a/torch/_export/db/logging.py b/torch/_export/db/logging.py index 2078113fef1..d034e4d4d41 100644 --- a/torch/_export/db/logging.py +++ b/torch/_export/db/logging.py @@ -1,7 +1,6 @@ -# mypy: allow-untyped-defs +from typing import Optional - -def exportdb_error_message(case_name: str): +def exportdb_error_message(case_name: str) -> str: from .examples import all_examples from torch._utils_internal import log_export_usage @@ -19,7 +18,7 @@ def exportdb_error_message(case_name: str): return f"{case_name} is unsupported." -def get_class_if_classified_error(e): +def get_class_if_classified_error(e: Exception) -> Optional[str]: """ Returns a string case name if the export error e is classified. Returns None otherwise.