From dd349207c5dced7218f4caa1820754010f3ff0b0 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Wed, 5 Feb 2025 19:40:10 +0000 Subject: [PATCH] Add check that envvar configs are boolean (#145454) So we don't get unexpected behavior when higher typed values are passed in Pull Request resolved: https://github.com/pytorch/pytorch/pull/145454 Approved by: https://github.com/c00w, https://github.com/jamesjwu --- test/test_utils_config_module.py | 16 ++++++++++++++-- torch/utils/_config_module.py | 22 ++++++++++++++++------ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/test/test_utils_config_module.py b/test/test_utils_config_module.py index 01d64c2343b..07892daab99 100644 --- a/test/test_utils_config_module.py +++ b/test/test_utils_config_module.py @@ -17,7 +17,7 @@ from torch.testing._internal import ( fake_config_module3 as config3, ) from torch.testing._internal.common_utils import run_tests, TestCase -from torch.utils._config_module import _UNSET_SENTINEL, Config +from torch.utils._config_module import _ConfigEntry, _UNSET_SENTINEL, Config class TestConfigModule(TestCase): @@ -378,7 +378,7 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""", AssertionError, msg="AssertionError: justknobs only support booleans, thisisnotvalid is not a boolean", ): - Config(default="bad", justknob="fake_knob") + _ConfigEntry(Config(default="bad", justknob="fake_knob")) def test_alias(self): self.assertFalse(config2.e_aliasing_bool) @@ -395,6 +395,18 @@ torch.testing._internal.fake_config_module3.e_func = _warnings.warn""", t["a"] = "b" self.assertFalse(config._is_default("e_dict")) + def test_invalid_config_int(self): + with self.assertRaises(AssertionError): + _ConfigEntry( + Config(default=2, env_name_default="FAKE_DISABLE", value_type=int) + ) + + def test_invalid_config_float(self): + with self.assertRaises(AssertionError): + _ConfigEntry( + Config(default=2, env_name_force="FAKE_DISABLE", value_type=float) + ) + if __name__ == "__main__": run_tests() diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 28e3684577a..bd74462747a 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -82,7 +82,6 @@ class _Config(Generic[T]): justknob: Optional[str] = None env_name_default: Optional[list[str]] = None env_name_force: Optional[list[str]] = None - value_type: Optional[type] = None alias: Optional[str] = None def __init__( @@ -103,17 +102,13 @@ class _Config(Generic[T]): self.env_name_force = _Config.string_or_list_of_string_to_list(env_name_force) self.value_type = value_type self.alias = alias - if self.justknob is not None: - assert isinstance( - self.default, bool - ), f"justknobs only support booleans, {self.default} is not a boolean" if self.alias is not None: assert ( default is _UNSET_SENTINEL and justknob is None and env_name_default is None and env_name_force is None - ), "if alias is set, default, justknob or env var cannot be set" + ), "if alias is set, none of {default, justknob and env var} can be set" @staticmethod def string_or_list_of_string_to_list( @@ -326,6 +321,21 @@ class _ConfigEntry: self.env_value_force = env_value break + # Ensure justknobs and envvars are allowlisted types + if self.justknob is not None and self.default is not None: + assert isinstance( + self.default, bool + ), f"justknobs only support booleans, {self.default} is not a boolean" + if self.value_type is not None and ( + config.env_name_default is not None or config.env_name_force is not None + ): + assert self.value_type in ( + bool, + str, + Optional[bool], + Optional[str], + ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + class ConfigModule(ModuleType): # NOTE: This should be kept in sync with _config_typing.pyi.