Serializing Pydantic with Metadata and Generic Fields

I recently needed a solution for serializing and deserializing Python objects. We initially
used pickle, but it was always a stop-gap, and it became an issue for debugging.

We switched to Pydantic,
which provided JSON serialization, but left us with a problem: generic fields.

We had messages like this:

from pydantic import BaseModel

class Shape(BaseModel):
    name: str

class Color(BaseModel):
    background: str = "black"
    foreground: str

class Update[T: BaseModel](BaseModel):
    model: T

This worked for pickle, which squirrels away the Python class metadata, but it was
out of scope for Pydantic, which needs to know all the candidate types up front in
order to build a union.

Fortunately, Pydantic provides hooks for custom serialization.

The Serializer

Our serializer looks like this:

from pydantic import BaseModel


def serialize_model_to_dict(model: BaseModel) -> dict:
    dct = model.model_dump(mode='json')
    dct['__module__'] = model.__class__.__module__
    dct['__qualname__'] = model.__class__.__qualname__
    return dct

The serializer allows us to separate the stages of JSON serialization. First we
break the object into a JSON-style Python dictionary (using mode='json'),
which ensures all of the dictionary values are JSON-serializable. Then we add
the metadata: the class's module name and qualified name. We can
then pass the dictionary on to Pydantic to handle the rest.
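The metadata step itself is plain Python. Here is a minimal sketch, with a hand-written dictionary standing in for the output of model_dump and a hypothetical User class in place of a real Pydantic model, so it runs without Pydantic installed:

```python
class User:
    """Hypothetical stand-in for a Pydantic model."""


def add_class_metadata(dct: dict, obj: object) -> dict:
    # Record where the class lives so it can be re-imported later.
    dct['__module__'] = type(obj).__module__
    dct['__qualname__'] = type(obj).__qualname__
    return dct


dct = add_class_metadata({"name": "John Doe"}, User())
print(dct['__qualname__'])  # User
```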

The Deserializer

In pydantic, deserialization is part of “validation”. The validation part
looked like this:

import json
import importlib

from pydantic import BaseModel, ValidationInfo


def deserialize_model_from_dict(dct: dict) -> BaseModel:
    module = dct.pop('__module__')
    qualname = dct.pop('__qualname__')
    cls = getattr(importlib.import_module(module), qualname)
    model = cls.model_validate(dct)
    return model


def deserialize_model_from_json(data: str | bytes | bytearray) -> BaseModel:
    dct = json.loads(data)
    model = deserialize_model_from_dict(dct)
    return model


def validate_model(value: BaseModel | dict | str | bytes | bytearray, info: ValidationInfo) -> BaseModel:

    match info.mode:

        case 'python':

            if isinstance(value, BaseModel):
                return value
            elif isinstance(value, dict):
                return deserialize_model_from_dict(value)
            else:
                raise ValueError(
                    f"unhandled type for mode {info.mode}: {type(value)}"
                )

        case 'json':

            if isinstance(value, (str, bytes, bytearray)):
                return deserialize_model_from_json(value)
            elif isinstance(value, dict):
                return deserialize_model_from_dict(value)
            else:
                raise ValueError(
                    f"unhandled type for mode {info.mode}: {type(value)}"
                )

        case _:

            raise ValueError(f"Invalid mode: {info.mode}")

There's a fair amount of code here. The key part is deserialize_model_from_dict,
which is the inverse of the serialization: it pops the module name and qualified
class name from the supplied dictionary, imports the model class, and validates
the remaining data against it.
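The import-and-lookup half can be exercised on its own. Here is a sketch using a standard-library class (decimal.Decimal) as a stand-in for a Pydantic model; note that getattr with a qualified name only resolves top-level classes, since the qualname of a nested class contains dots.

```python
import importlib

# Resolve a class from the stored metadata, the same lookup that
# deserialize_model_from_dict performs.
meta = {'__module__': 'decimal', '__qualname__': 'Decimal'}
cls = getattr(importlib.import_module(meta['__module__']), meta['__qualname__'])
value = cls('1.75')
print(value)  # 1.75
```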

The rest of the code provides the plumbing for pydantic. The entrypoint is the
validate_model function. This gets called via a number of different routes.
The first is when a python model is created, when the info.mode will be
'python' and the value will be of type BaseModel.

The second is when JSON data is provided: the info.mode is 'json' and
the value is some kind of text data.

The third is during serialization when a dictionary is generated by an
intermediate step. Here the info.mode will be either 'python' or 'json'
and the type of the value is a dict.

The Model Field Attributes

This all gets wired up in the following manner:

from typing import Annotated

from pydantic import (
    BaseModel,
    PlainSerializer,
    PlainValidator
)

from .serialization import serialize_model_to_dict, validate_model

class Update[T: BaseModel](BaseModel):
    model: Annotated[
        T,
        PlainSerializer(serialize_model_to_dict),
        PlainValidator(validate_model)
    ]

Note how we use typing annotations to attach the serializer and validator to the field.
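Annotated is the standard-library mechanism that makes this possible: it attaches arbitrary metadata to a type without changing the type itself, and Pydantic later reads the PlainSerializer and PlainValidator markers back out. A sketch of the mechanism, using a hypothetical Marker class in place of the Pydantic objects:

```python
from typing import Annotated, get_args, get_type_hints


class Marker:
    """Hypothetical metadata object, playing the role of PlainSerializer."""


class Box:
    value: Annotated[int, Marker()]


# include_extras=True preserves the Annotated metadata in the hints.
hints = get_type_hints(Box, include_extras=True)
base, *extras = get_args(hints['value'])
print(base)                      # <class 'int'>
print(type(extras[0]).__name__)  # Marker
```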

The generated JSON looks as follows:

>>> update = Update(
...     model=User(
...         name='John Doe',
...         date_of_birth=datetime(1990, 1, 1),
...         height=1.75
...     )
... )
>>> text = update.model_dump_json()
>>> print(text)
{"model":{"name":"John Doe","date_of_birth":"1990-01-01T00:00:00","height":1.75,"__module__":"kafka_ex1.models","__qualname__":"User"}}

We keep the dunder names to avoid collisions and to flatter the pythonistas.

Whole Message Serialization

It turned out that it would be useful to know the metadata of the root message too, as we wanted to save all the messages to a data store and replay them. This required just one extra function.

import json

from pydantic import BaseModel


def serialize_model_to_dict(model: BaseModel) -> dict:
    dct = model.model_dump(mode='json')
    dct['__module__'] = model.__class__.__module__
    dct['__qualname__'] = model.__class__.__qualname__
    return dct


def serialize_model_to_json_str(model: BaseModel) -> str:
    dct = serialize_model_to_dict(model)
    return json.dumps(dct)

Now we can create the metadata at the root level.

>>> update = Update(model=user)
>>> print(serialize_model_to_json_str(update))
{"model": {"name": "John Doe", "date_of_birth": "1990-01-01T00:00:00", "height": 1.75, "__module__": "demo.models", "__qualname__": "User"}, "__module__": "demo.models", "__qualname__": "Update"}

We can see the metadata is now present at the root level as well as on the nested model.

Finally we can do a full roundtrip with any base model:

>>> update = Update(model=user)
>>> text = serialize_model_to_json_str(update)
>>> roundtrip = deserialize_model_from_json(text)

Now we can save and retrieve any model. Happy days!

You can find the code for this blog here.