Return all extra fields passed to a pydantic model

I’m trying to get a list of all extra fields not defined in the schema. I saw a solution posted here, but it ignores any nested models. The optimal solution would create an attribute on the Pydantic model holding the extras, which I could access after a new object is created from the passed data, but I’m not sure whether this is even possible.

Here is the code I’m working with.

Edit: I want .extras to be something like a property that returns not just the extra data directly on that instance, but also the extra data on any nested model instances it holds.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, root_validator

unnecessary_data = {
    "name": "Lévy",
    "age": 3,
    "key_parent": "value",  # unnecessary
    "key2_parent": "value2",  # unnecessary x2
    "address": {
        "city": "Wonderland",
        "zip_code": "ABCDE",
        "number": 123,
        "key_child": 1232 # unnecessary x
    }
}


class NewBase(BaseModel):
    versio: Optional[str] = Field(alias='version')  # just to show that it supports alias too
    extra: Dict[str, Any]

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        all_required_field_names = {field.alias for field in cls.__fields__.values() if field.alias != 'extra'}  # to support alias

        # Move every key that is not a declared field (or alias) into `extra`.
        extra: Dict[str, Any] = {}
        for field_name in list(values):
            if field_name not in all_required_field_names:
                extra[field_name] = values.pop(field_name)
        values['extra'] = extra
        return values

class Address(NewBase):
    """
    Cat API Address definition
    """
    city: str
    zip_code: str
    number: int


class CatRequest(NewBase):
    """
    Cat API Request definition
    """
    name: str
    age: int
    address: Address


validated = CatRequest(**unnecessary_data)
print(validated.extras)
# desired output: ["key_parent", "key2_parent", "address.key_child"]

Answer

The following solution does not produce a list of keys as you described, but instead a nested dictionary of key-value pairs from the extra attributes:

from __future__ import annotations

from functools import cache
from typing import Any

from pydantic import BaseModel, root_validator


class NewBase(BaseModel):
    extra: dict[str, Any]

    @classmethod
    @cache
    def required_names(cls) -> set[str]:
        """This is just to make validation more efficient"""
        return {
            field.alias
            for field in cls.__fields__.values()
            if field.alias != 'extra'
        }

    @root_validator(pre=True)
    def build_extra(cls, values: dict[str, Any]) -> dict[str, Any]:
        extra: dict[str, Any] = {}
        for field_name in list(values.keys()):
            if field_name not in cls.required_names():
                extra[field_name] = values.pop(field_name)
        values['extra'] = extra
        return values

    def get_nested_extras(
        self,
        exclude: list[NewBase] | None = None,
    ) -> dict[str, Any]:
        """Recursively retrieves all nested `extra` attributes."""
        if exclude is None:
            exclude = []
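        # To avoid infinite recursion,
        # we need to track which model instances have been checked already.
        # Compare by identity: pydantic models compare equal by field values,
        # so `in` would also skip a distinct instance that happens to be equal.
        if not any(self is seen for seen in exclude):
            exclude.append(self)
        output = self.extra.copy()
        for field_name in self.__fields__.keys():
            obj = getattr(self, field_name)
            if isinstance(obj, NewBase) and not any(obj is seen for seen in exclude):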
                output[field_name] = obj.get_nested_extras(exclude=exclude)
        return output

    @property
    def extras(self) -> dict[str, Any]:
        return self.get_nested_extras()


class Address(NewBase):
    city: str
    zip_code: str
    number: int


class CatRequest(NewBase):
    name: str
    age: int
    address: Address


if __name__ == '__main__':
    data = {
        "name": "Lévy",
        "age": 3,
        "key_parent": "value",  # extra
        "key2_parent": "value2",  # extra
        "address": {
            "city": "Wonderland",
            "zip_code": "ABCDE",
            "number": 123,
            "key_child": 1232  # extra
        }
    }
    validated = CatRequest(**data)
    print(validated.extras)

Output:

{'key_parent': 'value', 'key2_parent': 'value2', 'address': {'key_child': 1232}}

If you do want your list of keys instead, you can use these methods:

    ...

    def get_nested_extra_fields(
        self,
        exclude: list[NewBase] | None = None,
    ) -> list[str]:
        """Recursively retrieves all nested `extra` keys."""
        if exclude is None:
            exclude = []
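        # To avoid infinite recursion,
        # we need to track which model instances have been checked already.
        # Compare by identity: pydantic models compare equal by field values,
        # so `in` would also skip a distinct instance that happens to be equal.
        if not any(self is seen for seen in exclude):
            exclude.append(self)
        output = list(self.extra.keys())
        for field_name in self.__fields__.keys():
            obj = getattr(self, field_name)
            if isinstance(obj, NewBase) and not any(obj is seen for seen in exclude):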
                nested_fields = obj.get_nested_extra_fields(exclude=exclude)
                output.extend(f"{field_name}.{k}" for k in nested_fields)
        return output

    @property
    def extra_fields(self) -> list[str]:
        return self.get_nested_extra_fields()

Calling extra_fields on the previous example model instance gives the following output:

['key_parent', 'key2_parent', 'address.key_child']
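As an alternative, the same list can be derived from the nested dictionary that extras returns. The following is only a sketch: flatten_extra_keys is a hypothetical helper, not part of pydantic, and it assumes that no extra value is itself a dictionary, because such a value would be indistinguishable from a nested model here.

```python
from typing import Any, Dict, List


def flatten_extra_keys(nested: Dict[str, Any], prefix: str = "") -> List[str]:
    """Turn the nested extras dictionary into a flat list of dotted key paths."""
    keys: List[str] = []
    for key, value in nested.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            # Assumed to be the extras of a nested model: descend one level.
            keys.extend(flatten_extra_keys(value, prefix=f"{path}."))
        else:
            keys.append(path)
    return keys


extras = {'key_parent': 'value', 'key2_parent': 'value2', 'address': {'key_child': 1232}}
print(flatten_extra_keys(extras))
# ['key_parent', 'key2_parent', 'address.key_child']
```

A nested model without any extras shows up as an empty dictionary and therefore contributes nothing to the list, which matches what extra_fields returns.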

Both solutions just recursively iterate over all fields except extra.

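The code uses the pydantic 1.x API (root_validator, __fields__). The annotations assume Python 3.10+. If this causes problems, replace every type of the form T | None with typing.Optional[T]. If you are on a version below 3.9 (first of all, upgrade your Python :P), also replace things like list[str] with typing.List[str], and replace functools.cache, which was added in 3.9, with functools.lru_cache(maxsize=None).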

The cached method for retrieving the required_names is just for efficiency, so that the set of names is only ever computed once for any given model class.
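To illustrate the caching behaviour without pydantic, here is a small standalone sketch. The classes Base and Child, the declared attribute and the computed_for list are hypothetical stand-ins invented for this demonstration. Only the stacking of classmethod and functools.cache mirrors the answer's code.

```python
from functools import cache
from typing import FrozenSet, List

computed_for: List[str] = []  # records every time the method body actually runs


class Base:
    declared = ("name",)  # hypothetical stand-in for the model's field aliases

    @classmethod
    @cache  # placed below @classmethod, so it wraps the plain function
    def required_names(cls) -> FrozenSet[str]:
        computed_for.append(cls.__name__)
        return frozenset(cls.declared)


class Child(Base):
    declared = ("name", "age")


for _ in range(3):
    Base.required_names()
    Child.required_names()

print(computed_for)
# ['Base', 'Child']
```

The cache is keyed on cls, so each subclass gets its own entry. Every call returns the same cached object, which is why this sketch returns a frozenset. With a plain set, as in the answer, callers should take care not to mutate the result.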

I left the build_extra root validator basically unchanged.

Caveat:

In the current implementation, the extras of nested models inside container fields are ignored, for example if you had something like addresses: list[Address] on your CatRequest.

If I find the time, I’ll try and amend the solution later. Though I suspect that this may be non-trivial because of the different “shapes” such fields can come in. Also it is not entirely clear how that should look in the output.
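One possible starting point, offered only as a sketch under assumptions: a helper that walks lists, tuples and dictionary values and yields every model instance it finds together with a path. The name iter_nested_models and the path format (addresses[0], by_name['home']) are choices made up for this illustration, and FakeModel is a stand-in for NewBase so that the snippet runs without pydantic.

```python
from typing import Any, Iterator, Tuple


def iter_nested_models(
    value: Any,
    model_type: type,
    path: str,
) -> Iterator[Tuple[str, Any]]:
    """Yield (path, instance) for each `model_type` instance inside `value`.

    Only lists, tuples and dict values are searched; other containers are ignored.
    """
    if isinstance(value, model_type):
        yield path, value
    elif isinstance(value, dict):
        for key, item in value.items():
            yield from iter_nested_models(item, model_type, f"{path}[{key!r}]")
    elif isinstance(value, (list, tuple)):
        for index, item in enumerate(value):
            yield from iter_nested_models(item, model_type, f"{path}[{index}]")


class FakeModel:
    """Stand-in for a NewBase instance; it only carries an `extra` dict."""

    def __init__(self, **extra: Any) -> None:
        self.extra = extra


addresses = [FakeModel(key_child=1232), FakeModel()]
found = {
    path: model.extra
    for path, model in iter_nested_models(addresses, FakeModel, "addresses")
}
print(found)
# {'addresses[0]': {'key_child': 1232}, 'addresses[1]': {}}
```

Inside get_nested_extras, the isinstance check could then be replaced by a loop over iter_nested_models(getattr(self, field_name), NewBase, field_name). For a field that holds a model directly, the yielded path is just the field name, so the existing output would stay the same.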

Hope this helps.
