I'm trying to adapt my app to the Hydra framework. I use a structured config schema and want to restrict the possible values of some fields. Is there a way to do that?
Here is my code:
my_app.py:
    from dataclasses import dataclass

    import hydra


    @dataclass
    class Config:
        # possible values are 'foo' and 'bar'
        some_value: str = "foo"


    @hydra.main(config_path="configs", config_name="config")
    def main(cfg: Config):
        print(cfg)


    if __name__ == "__main__":
        main()
configs/config.yaml:
    # value is incorrect.
    # I need hydra to throw an exception in this case
    some_value: "barrr"
Answer
A few options:
1) If your acceptable values are enumerable, use an Enum type:
    from enum import Enum
    from dataclasses import dataclass


    class SomeValue(Enum):
        foo = 1
        bar = 2


    @dataclass
    class Config:
        # possible values are 'foo' and 'bar'
        some_value: SomeValue = SomeValue.foo
If no fancy logic is needed to validate some_value, this is the solution I would recommend.
2) If you are using yaml files, you can use OmegaConf to register a custom resolver:
    # my_python_file.py
    import hydra
    from omegaconf import OmegaConf


    def check_some_value(value: str) -> str:
        # Reject anything other than the two allowed values.
        assert value in ("foo", "bar")
        return value


    # Must run before the config is accessed.
    OmegaConf.register_new_resolver("check_foo_bar", check_some_value)


    @hydra.main(...)
    ...

    if __name__ == "__main__":
        main()
    # my_yaml_file.yaml
    some_value: ${check_foo_bar:foo}
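The resolver runs lazily: only when you access cfg.some_value in your Python code is the check executed. If the value does not pass check_some_value, the access fails with the AssertionError (which OmegaConf may wrap in one of its own exception types, depending on the version).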
3) After config composition is completed, you can call OmegaConf.to_object to create an instance of your dataclass. This means that the dataclass's __post_init__ function will get called.
    import hydra
    from dataclasses import dataclass
    from omegaconf import DictConfig, OmegaConf


    @dataclass
    class Config:
        # possible values are 'foo' and 'bar'
        some_value: str = "foo"

        def __post_init__(self) -> None:
            assert self.some_value in ("foo", "bar")


    @hydra.main(config_path="configs", config_name="config")
    def main(dict_cfg: DictConfig):
        # Assumes dict_cfg is backed by the Config schema;
        # for an untyped config, to_object returns a plain dict.
        cfg: Config = OmegaConf.to_object(dict_cfg)
        print(cfg)


    if __name__ == "__main__":
        main()
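One caveat with `assert` in `__post_init__`: assert statements are skipped when Python runs with the `-O` flag, so the check can silently disappear. A minimal standard-library sketch of the same validation with an explicit exception follows. The `ALLOWED` constant and the choice of `ValueError` are my own, not part of the original answer.

```python
from dataclasses import dataclass

# Hypothetical constant holding the permitted values.
ALLOWED = ("foo", "bar")


@dataclass
class Config:
    some_value: str = "foo"

    def __post_init__(self) -> None:
        # Unlike assert, this check is not removed under `python -O`.
        if self.some_value not in ALLOWED:
            raise ValueError(
                f"some_value must be one of {ALLOWED}, got {self.some_value!r}"
            )


ok = Config(some_value="bar")

try:
    Config(some_value="barrr")
    rejected = False
except ValueError as exc:
    rejected = True
    print(f"rejected: {exc}")
```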