config: create Config objects with JK support (#138766)

This teaches install_config_module (and the underlying code) to
understands Config objects. Additionally we've added a JK option to this
which resolves the JK.

This config gets stored within the _ConfigEntry class and is evaluated
when __getattr__ is called. If justknobs is set, it'll call
justknobs_check to see the result.

Due to preceeding work, basically everything works correctly here and we
had to update a couple of tests, and modify the getattr behaviour.

Note that we are updating the justknob_check function to support a
default option, to make default work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138766
Approved by: https://github.com/ezyang
This commit is contained in:
Colin L. Rice 2024-11-01 10:13:07 -06:00 committed by PyTorch MergeBot
parent 6fc63b4ef1
commit abc5d59dcb
4 changed files with 79 additions and 11 deletions

View file

@ -90,6 +90,9 @@ class TestConfigModule(TestCase):
"e_compile_ignored": True,
"magic_cache_config_ignored": True,
"_save_config_ignore": ["e_ignored"],
"e_config": True,
"e_jk": True,
"e_jk_false": False,
},
)
config.e_bool = False
@ -117,6 +120,9 @@ class TestConfigModule(TestCase):
"nested.e_bool": True,
"e_ignored": True,
"e_compile_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
},
)
config.e_bool = False
@ -146,29 +152,28 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
def test_get_hash(self):
self.assertEqual(
config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
)
# Test cached value
self.assertEqual(
config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
)
self.assertEqual(
config._hash_digest, b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
)
config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake")
# BUG
config.e_bool = False
self.assertNotEqual(
config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
)
config.e_bool = True
# Test ignored values
config.e_compile_ignored = False
self.assertEqual(
config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6"
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
)
for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL
@ -194,6 +199,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"_cache_config_ignore_prefix": ["magic_cache_config"],
"_save_config_ignore": ["e_ignored"],
"magic_cache_config_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
},
)
p2 = config.to_dict()
@ -216,6 +224,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"_cache_config_ignore_prefix": ["magic_cache_config"],
"_save_config_ignore": ["e_ignored"],
"magic_cache_config_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
},
)
p3 = config.get_config_copy()
@ -238,6 +249,9 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"_cache_config_ignore_prefix": ["magic_cache_config"],
"_save_config_ignore": ["e_ignored"],
"magic_cache_config_ignored": True,
"e_config": True,
"e_jk": True,
"e_jk_false": False,
},
)

View file

@ -280,7 +280,7 @@ def justknobs_feature(
return justknobs_check(name)
def justknobs_check(name: str) -> bool:
def justknobs_check(name: str, default: bool = True) -> bool:
"""
This function can be used to killswitch functionality in FB prod,
where you can toggle this value to False in JK without having to
@ -303,7 +303,7 @@ def justknobs_check(name: str) -> bool:
fork safe and you will break anyone who forks the process and then
hits JK again.
"""
return True
return default
def justknobs_getval_int(name: str) -> int:

View file

@ -1,7 +1,7 @@
import sys
from typing import Optional
from torch.utils._config_module import install_config_module
from torch.utils._config_module import Config, install_config_module
e_bool = True
@ -18,6 +18,9 @@ _e_ignored = True
magic_cache_config_ignored = True
# [@compile_ignored: debug]
e_compile_ignored = True
e_config = Config(default=True)
e_jk = Config(justknob="does_not_exist")
e_jk_false = Config(justknob="does_not_exist", default=False)
class nested:

View file

@ -13,6 +13,40 @@ from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union
from typing_extensions import deprecated
from unittest import mock
from torch._utils_internal import justknobs_check
@dataclass
class Config:
"""Represents a config with richer behaviour than just a default value.
::
i.e.
foo = Config(justknob="//foo:bar", default=False)
install_config_module(...)
This configs must be installed with install_config_module to be used
Precedence Order:
user_override: If a user sets a value (i.e. foo.bar=True), that
has the highest precendance and is always respected
justknob: If this pytorch installation supports justknobs, that will
override defaults, but will not override the user_override precendence.
default: This value is the lowest precendance, and will be used if nothing is
set.
Arguments:
justknob: the name of the feature / JK. In OSS this is unused.
default: is the value to default this knob to in OSS.
"""
default: Any = True
justknob: Optional[str] = None
def __init__(self, default: Any = True, justknob: Optional[str] = None):
# python 3.9 does not support kw_only on the dataclass :(.
self.default = default
self.justknob = justknob
# Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
@ -39,12 +73,18 @@ def install_config_module(module: ModuleType) -> None:
key.startswith("__")
or isinstance(value, (ModuleType, FunctionType))
or (hasattr(value, "__module__") and value.__module__ == "typing")
# Handle from torch.utils._config_module import Config
or (isinstance(value, type) and issubclass(value, Config))
):
continue
name = f"{prefix}{key}"
if isinstance(value, CONFIG_TYPES):
config[name] = _ConfigEntry(default=value)
config[name] = _ConfigEntry(Config(default=value))
if dest is module:
delattr(module, key)
elif isinstance(value, Config):
config[name] = _ConfigEntry(value)
if dest is module:
delattr(module, key)
elif isinstance(value, type):
@ -123,6 +163,12 @@ class _ConfigEntry:
# The value specified by the user when they overrode the configuration
# _UNSET_SENTINEL indicates the value is not set.
user_override: Any = _UNSET_SENTINEL
# The justknob to check for this config
justknob: Optional[str] = None
def __init__(self, config: Config):
self.default = config.default
self.justknob = config.justknob
class ConfigModule(ModuleType):
@ -157,6 +203,10 @@ class ConfigModule(ModuleType):
if config.user_override is not _UNSET_SENTINEL:
return config.user_override
if config.justknob is not None:
# JK only supports bools and ints
return justknobs_check(name=config.justknob, default=config.default)
# Note that reference types can still be modified, so we
# copy them to user_overrides in case the user overrides
# them
@ -164,6 +214,7 @@ class ConfigModule(ModuleType):
config.user_override = copy.deepcopy(config.default)
return config.user_override
return config.default
except KeyError as e:
# make hasattr() work properly
raise AttributeError(f"{self.__name__}.{name} does not exist") from e
@ -204,7 +255,7 @@ class ConfigModule(ModuleType):
if ignored_keys and key in ignored_keys:
if skip_default and not self._is_default(key):
warnings.warn(
f"Skipping serialization of {key} value {self._config[key]}"
f"Skipping serialization of {key} value {getattr(self, key)}"
)
continue
if ignored_prefixes: