
How do I construct a self-referential/recursive SQLModel

I want to define a model that has a self-referential (or recursive) foreign key using SQLModel. (This relationship pattern is also sometimes referred to as an adjacency list.) The pure SQLAlchemy implementation is described in the "Adjacency List Relationships" section of the SQLAlchemy documentation.
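For reference, here is a condensed sketch of that pure-SQLAlchemy adjacency list, adapted (not copied verbatim) from the pattern in the SQLAlchemy docs. It assumes SQLAlchemy 1.4 or 2.0 and uses the classic `Column` style so it runs on both; the in-memory SQLite URL and the `root`/`child` data are illustrative only.

```python
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()


class Node(Base):
    __tablename__ = "node"

    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("node.id"))  # FK pointing back at the same table
    data = Column(String(50))

    # One-to-many: all rows whose parent_id equals this row's id.
    children = relationship("Node", back_populates="parent")
    # Many-to-one: remote_side tells SQLAlchemy that `id` is on the parent's end.
    parent = relationship("Node", back_populates="children", remote_side=[id])


engine = create_engine("sqlite://")  # in-memory database, for illustration
Base.metadata.create_all(engine)

with Session(engine) as session:
    root = Node(data="root")
    child = Node(data="child", parent=root)  # back_populates also fills root.children
    session.add(root)  # child is cascaded in through root.children
    session.commit()
    root_children = [n.data for n in root.children]
    child_parent = child.parent.data

print(root_children, child_parent)
```

The question is essentially how to express the `remote_side=[id]` part when the model is declared with SQLModel instead.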

Let’s say I want to implement the basic tree structure from the SQLAlchemy example referenced above, where I have a Node model and each instance has an id primary key, a data field (say of type str), and an optional reference (read: foreign key) to another node that we call its parent node (field name parent_id).

Ideally, every Node object should have a parent attribute, which is None if the node has no parent node and otherwise holds (a reference to) the parent Node object.

And even better, every Node object should have a children attribute, which will be a list of Node objects that reference it as their parent.

The question is twofold:

  1. What is an elegant way to implement this with SQLModel?

  2. How would I create such node instances and insert them into the database?


Answer

The sqlmodel.Relationship function accepts an sa_relationship_kwargs parameter, through which we can explicitly pass additional keyword arguments to the sqlalchemy.orm.relationship constructor that SQLModel calls under the hood. This parameter expects a mapping from SQLAlchemy parameter names (as strings) to the values we want to pass through as arguments.

Since SQLAlchemy relationships provide the remote_side parameter for just such an occasion, we can leverage that directly to construct the self-referential pattern with minimal code. The documentation mentions this in passing, but crucially the remote_side value

may be passed as a Python-evaluable string when using Declarative.

This is exactly what we need. The only missing piece is then the proper use of the back_populates parameter, and we can build the model like so:

from typing import Optional
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine


class Node(SQLModel, table=True):
    __tablename__ = 'node'  # just to be explicit

    id: Optional[int] = Field(default=None, primary_key=True)
    data: str
    parent_id: Optional[int] = Field(
        foreign_key='node.id',  # notice the lowercase "n" to refer to the database table name
        default=None,
        nullable=True
    )
    parent: Optional['Node'] = Relationship(
        back_populates='children',
        sa_relationship_kwargs=dict(
            remote_side='Node.id'  # notice the uppercase "N" to refer to this table class
        )
    )
    children: list['Node'] = Relationship(back_populates='parent')

# more code below...

Side note: We define id as optional, as is customary with SQLModel, so that our IDE does not complain when we create an instance whose id will only be known after it has been written to the database. The parent_id and parent attributes are optional because not every node needs a parent in our model.

To test that everything works the way we expect it to:

def test() -> None:
    # Initialize database & session:
    sqlite_file_name = 'database.db'
    sqlite_uri = f'sqlite:///{sqlite_file_name}'
    engine = create_engine(sqlite_uri, echo=True)
    SQLModel.metadata.drop_all(engine)
    SQLModel.metadata.create_all(engine)
    session = Session(engine)

    # Initialize nodes:
    root_node = Node(data='I am root')

    # Set the children's `parent` attributes;
    # the parent nodes' `children` lists are then set automatically:
    node_a = Node(parent=root_node, data='a')
    node_b = Node(parent=root_node, data='b')
    node_aa = Node(parent=node_a, data='aa')
    node_ab = Node(parent=node_a, data='ab')

    # Add to the parent node's `children` list;
    # the child node's `parent` attribute is then set automatically:
    node_ba = Node(data='ba')
    node_b.children.append(node_ba)

    # Commit to DB:
    session.add(root_node)
    session.commit()

    # Do some checks:
    assert root_node.children == [node_a, node_b]
    assert node_aa.parent.parent.children[1].parent is root_node
    assert node_ba.parent.data == 'b'
    assert all(n.data.startswith('a') for n in node_ab.parent.children)
    assert (node_ba.parent.parent.id == node_ba.parent.parent_id == root_node.id
            and isinstance(root_node.id, int))
    session.close()


if __name__ == '__main__':
    test()

All the assertions are satisfied and the test runs without a hitch.

Also, by using the echo=True switch for the database engine, we can verify in our log output that the table is created as we expected:

CREATE TABLE node (
    id INTEGER, 
    data VARCHAR NOT NULL, 
    parent_id INTEGER, 
    PRIMARY KEY (id), 
    FOREIGN KEY(parent_id) REFERENCES node (id)
)