【问题标题】:Validate classes (hydra-core lists) with pydantic使用 pydantic 验证类(hydra-core 列表)
【发布时间】:2022-11-02 22:29:18
【问题描述】:

1. 背景

如何验证 pydantic 中的特定类?

我正在使用pydantic 来验证由hydra 解析的yaml 列表参数,以便稍后传递给建模例程。问题是 hydra 字典不包含值列表,而是包含这些值的类。如何验证这些参数?

2. 例子

在以下示例中,有 2 个文件:

  • cfg.yaml 包含要验证的参数
  • main.py 包含加载和验证 cfg.yaml 的指令

2.1 配置文件cfg.yaml

params_list:
  - 10
  - 0
  - 20

2.2 解析器/验证器文件main.py

import hydra
import pydantic
from omegaconf import DictConfig, OmegaConf
from typing import List

class Test(pydantic.BaseModel):
    params_list: List[int]

@hydra.main(config_path=".", config_name="cfg.yaml")
def go(cfg: DictConfig):
    parsed_cfg = Test(**cfg)
    print(parsed_cfg)

if __name__ == "__main__":
    go()

3.问题

执行python3 main.py时出现以下错误

值不是有效列表(type=type_error.list)

那是因为 hydra 有一个特定的类来处理列表,称为omegaconf.listconfig.ListConfig,可以通过添加来检查

print(type(cfg['params_list']))

go() 函数定义之后。

4. 指导

我知道我可能必须告诉pydantic 来验证这个特定的东西,但我只是不知道具体如何。

  • Here 提供了一些提示,但对于我猜的任务来说似乎很重要。
  • 另一个想法是为数据属性创建一个泛型类型(如params_list: Generic),然后使用验证器装饰器将其转换为列表,大致如下:
class ParamsList(pydantic.BaseModel):
  params_list: ???????? #i don't know that to do here
  @p.validator("params_list")
  @classmethod
    def validate_path(cls, v) -> None:
        """validate if it's a list"""
        if type(list(v)) != list:
            raise TypeError("It's not a list. Make it become a list")
        return list(v)

帮助!:关于如何解决它的任何想法?

如何重新创建示例

  1. 在文件夹中添加第 2.1 和 2.2 节中描述的文件。
  2. 还使用包pydantichydra-core 创建一个requirements.txt 文件
  3. 创建并激活环境后,运行python3 main.py

【问题讨论】:

  • 在将数据传递给Test 类之前,您是否考虑在DictConfig 对象上调用OmegaConf.to_container?例如:parsed_cfg = Test(**OmegaConf.to_container(cfg))

标签: python validation pydantic fb-hydra


【解决方案1】:

Pydantic 不接受 DictConfig 格式。当您尝试使用 pydantic 模型解析 hydra 配置时,您必须首先将 DictConfig 转换为原生 Python Dict。您可以使用OmegaConf.to_object(cfg) 执行此操作。

我假设您使用 Python 3.10 或更高版本。注意使用version_base="1.2" 来获取最新的 hydra 版本。

这应该有效:

import hydra
import pydantic
from omegaconf import DictConfig, OmegaConf

class Test(pydantic.BaseModel):
    params_list: list[int]


@hydra.main(config_path=".", config_name="cfg.yaml", version_base="1.2")
def go(cfg: DictConfig):
    print(cfg)
    d_cfg = OmegaConf.to_object(cfg)
    parsed_cfg = Test(**d_cfg)
    print(parsed_cfg)


if __name__ == "__main__":
    go()

【讨论】:

    猜你喜欢
    • 2021-08-14
    • 2021-04-04
    • 2019-09-04
    • 1970-01-01
    • 2021-05-24
    • 1970-01-01
    • 2021-09-06
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多