diff --git a/test/test_utils_config_module.py b/test/test_utils_config_module.py index c8e56381abd..add6af20b02 100644 --- a/test/test_utils_config_module.py +++ b/test/test_utils_config_module.py @@ -6,6 +6,8 @@ import pickle os.environ["ENV_TRUE"] = "1" os.environ["ENV_FALSE"] = "0" +from typing import Optional + from torch.testing._internal import fake_config_module as config from torch.testing._internal.common_utils import run_tests, TestCase from torch.utils._config_module import _UNSET_SENTINEL @@ -15,6 +17,7 @@ class TestConfigModule(TestCase): def test_base_value_loading(self): self.assertTrue(config.e_bool) self.assertTrue(config.nested.e_bool) + self.assertTrue(config.e_optional) self.assertEqual(config.e_int, 1) self.assertEqual(config.e_float, 1.0) self.assertEqual(config.e_string, "string") @@ -28,6 +31,10 @@ class TestConfigModule(TestCase): ): config.does_not_exist + def test_type_loading(self): + self.assertEqual(config.get_type("e_optional"), Optional[bool]) + self.assertEqual(config.get_type("e_none"), Optional[bool]) + def test_overrides(self): config.e_bool = False self.assertFalse(config.e_bool) @@ -51,6 +58,10 @@ class TestConfigModule(TestCase): self.assertEqual(config.e_none, "not none") config.e_none = None self.assertEqual(config.e_none, None) + config.e_optional = None + self.assertEqual(config.e_optional, None) + config.e_optional = False + self.assertEqual(config.e_optional, False) with self.assertRaises( AttributeError, msg="fake_config_module.does_not_exist does not exist" ): @@ -112,6 +123,7 @@ class TestConfigModule(TestCase): "e_env_default": True, "e_env_default_FALSE": False, "e_env_force": True, + "e_optional": True, }, ) config.e_bool = False @@ -145,6 +157,7 @@ class TestConfigModule(TestCase): "e_env_default": True, "e_env_default_FALSE": False, "e_env_force": True, + "e_optional": True, }, ) config.e_bool = False @@ -173,30 +186,22 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" config._config[k].user_override = _UNSET_SENTINEL def test_get_hash(self): - self.assertEqual( - config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" - ) + self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") # Test cached value - self.assertEqual( - config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" - ) - self.assertEqual( - config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" - ) + self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") + self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") config._hash_digest = "fake" self.assertEqual(config.get_hash(), "fake") config.e_bool = False self.assertNotEqual( - config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" + config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ" ) config.e_bool = True # Test ignored values config.e_compile_ignored = False - self.assertEqual( - config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" - ) + self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") for k in config._config: config._config[k].user_override = _UNSET_SENTINEL @@ -227,6 +232,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" "e_env_default": True, "e_env_default_FALSE": False, "e_env_force": True, + "e_optional": True, }, ) p2 = config.to_dict() @@ -255,6 +261,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" "e_env_default": True, "e_env_default_FALSE": False, "e_env_force": True, + "e_optional": True, }, ) p3 = config.get_config_copy() @@ -283,6 +290,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']"" "e_env_default": True, "e_env_default_FALSE": False, "e_env_force": True, + "e_optional": True, }, ) diff --git a/torch/__init__.py b/torch/__init__.py index 2b879a31740..970ac7d55e2 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -25,6 +25,7 @@ from typing import ( Any as _Any, Callable as _Callable, Dict as _Dict, + get_origin as _get_origin, Optional as _Optional, overload as _overload, Set as _Set, @@ -2280,13 +2281,18 @@ class _TorchCompileInductorWrapper: raise RuntimeError( f"Unexpected optimization option {key}, known options are {list(current_config.keys())}" ) - if type(val) is not type(current_config[attr_name]): - val_type_str = type(val).__name__ - expected_type_str = type(current_config[attr_name]).__name__ - raise RuntimeError( - f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}" - ) - self.config[attr_name] = val + attr_type = config.get_type(attr_name) # type: ignore[attr-defined] + # Subscriptable generic types don't support isinstance so skip the type + # check. There doesn't seem to be a good way of checking membership without + # 3rd party libraries. + if _get_origin(attr_type) is None: + if not isinstance(val, attr_type): + val_type_str = type(val).__name__ + expected_type_str = type(current_config[attr_name]).__name__ + raise RuntimeError( + f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}" + ) + self.config[attr_name] = val def __call__(self, model_, inputs_): from torch._inductor.compile_fx import compile_fx diff --git a/torch/testing/_internal/fake_config_module.py b/torch/testing/_internal/fake_config_module.py index 1d5bed8fe0e..5ceb692b2dd 100644 --- a/torch/testing/_internal/fake_config_module.py +++ b/torch/testing/_internal/fake_config_module.py @@ -13,6 +13,7 @@ e_set = {1} e_tuple = (1,) e_dict = {1: 2} e_none: Optional[bool] = None +e_optional: Optional[bool] = True e_ignored = True _e_ignored = True magic_cache_config_ignored = True diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 41569465d25..a097d42cf19 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -5,6 +5,7 @@ import inspect import io import os import pickle +import sys import tokenize import unittest import warnings @@ -54,6 +55,7 @@ class Config: justknob: Optional[str] = None env_name_default: Optional[str] = None env_name_force: Optional[str] = None + value_type: Optional[type] = None def __init__( self, @@ -61,12 +63,14 @@ class Config: justknob: Optional[str] = None, env_name_default: Optional[str] = None, env_name_force: Optional[str] = None, + value_type: Optional[type] = None, ): # python 3.9 does not support kw_only on the dataclass :(. self.default = default self.justknob = justknob self.env_name_default = env_name_default self.env_name_force = env_name_force + self.value_type = value_type # Types saved/loaded in configs @@ -99,6 +103,10 @@ def install_config_module(module: ModuleType) -> None: prefix: str, ) -> None: """Walk the module structure and move everything to module._config""" + if sys.version_info[:2] < (3, 10): + type_hints = getattr(source, "__annotations__", {}) + else: + type_hints = inspect.get_annotations(source) for key, value in list(source.__dict__.items()): if ( key.startswith("__") @@ -111,7 +119,10 @@ def install_config_module(module: ModuleType) -> None: name = f"{prefix}{key}" if isinstance(value, CONFIG_TYPES): - config[name] = _ConfigEntry(Config(default=value)) + annotated_type = type_hints.get(key, None) + config[name] = _ConfigEntry( + Config(default=value, value_type=annotated_type) + ) if dest is module: delattr(module, key) elif isinstance(value, Config): @@ -192,6 +203,8 @@ _UNSET_SENTINEL = object() class _ConfigEntry: # The default value specified in the configuration default: Any + # The type of the configuration value + value_type: type # The value specified by the user when they overrode the configuration # _UNSET_SENTINEL indicates the value is not set. user_override: Any = _UNSET_SENTINEL @@ -203,6 +216,9 @@ class _ConfigEntry: def __init__(self, config: Config): self.default = config.default + self.value_type = ( + config.value_type if config.value_type is not None else type(self.default) + ) self.justknob = config.justknob if config.env_name_default is not None: if (env_value := _read_env_variable(config.env_name_default)) is not None: @@ -314,6 +330,9 @@ class ConfigModule(ModuleType): config[key] = copy.deepcopy(getattr(self, key)) return config + def get_type(self, config_name: str) -> type: + return self._config[config_name].value_type + def save_config(self) -> bytes: """Convert config to a pickled blob""" ignored_keys = getattr(self, "_save_config_ignore", [])