I am trying to adapt my app to the Hydra framework. I use a structured config schema and want to restrict the possible values of some fields. Is there a way to do that?
Here is my code:
my_app.py:
import hydra
from dataclasses import dataclass


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"


@hydra.main(config_path="configs", config_name="config")
def main(cfg: Config):
    print(cfg)


if __name__ == "__main__":
    main()
configs/config.yaml:
# value is incorrect.
# I need hydra to throw an exception in this case
some_value: "barrr"
Answer
A few options:
1) If your acceptable values are enumerable, use an Enum type:
from enum import Enum
from dataclasses import dataclass


class SomeValue(Enum):
    foo = 1
    bar = 2


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: SomeValue = SomeValue.foo
If no fancy logic is needed to validate some_value, this is the solution I would recommend.
2) If you are using yaml files, you can use OmegaConf to register a custom resolver:
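# my_python_file.py
import hydra
from omegaconf import OmegaConf


def check_some_value(value: str) -> str:
    # Reject anything other than the two allowed values.
    assert value in ("foo", "bar")
    return value


# Must run before the config is accessed, so register at import time.
OmegaConf.register_new_resolver("check_foo_bar", check_some_value)


@hydra.main(...)  # arguments elided in the original
def main(cfg):
    ...  # body elided in the original


if __name__ == "__main__":
    main()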

# my_yaml_file.yaml
some_value: ${check_foo_bar:foo}
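When you access cfg.some_value in your Python code, the assertion in check_some_value fails if the value is not allowed, and the resulting AssertionError surfaces as an exception (depending on the OmegaConf version, it may be wrapped in an interpolation error). Note that the check runs when the value is accessed, not when the config is loaded.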
3) After config composition is completed, you can call OmegaConf.to_object to create an instance of your dataclass. This means that the dataclass’s __post_init__ method will get called.
import hydra
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"

    def __post_init__(self) -> None:
        # Runs whenever a Config instance is constructed.
        assert self.some_value in ("foo", "bar")


@hydra.main(config_path="configs", config_name="config")
def main(dict_cfg: DictConfig):
    # Converting to the dataclass triggers __post_init__ validation.
    cfg: Config = OmegaConf.to_object(dict_cfg)
    print(cfg)


if __name__ == "__main__":
    main()
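The validation in option 3 lives entirely in the dataclass, so it can be exercised with the standard library alone. In the sketch below, `build` and `ALLOWED_VALUES` are hypothetical names, and `build` is only a stand-in for what `OmegaConf.to_object` does when it constructs the dataclass. It also swaps `assert` for an explicit `ValueError`, since assertions are removed when Python runs with `-O`.

```python
from dataclasses import dataclass

# Hypothetical constant listing the permitted values.
ALLOWED_VALUES = ("foo", "bar")


@dataclass
class Config:
    some_value: str = "foo"

    def __post_init__(self) -> None:
        # Called automatically at the end of the generated __init__.
        if self.some_value not in ALLOWED_VALUES:
            raise ValueError(
                f"some_value must be one of {ALLOWED_VALUES}, got {self.some_value!r}"
            )


def build(raw: dict) -> Config:
    """Stand-in for OmegaConf.to_object: construct the dataclass from plain values."""
    return Config(**raw)


print(build({"some_value": "bar"}))

try:
    build({"some_value": "barrr"})
    failed = False
except ValueError as exc:
    failed = True
    print("rejected:", exc)
```

Keep in mind that OmegaConf.to_object only returns dataclass instances for nodes backed by a structured config, so the schema has to be attached to the composed config for `__post_init__` to run.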