
How to overload a function with default parameters in Python

I have the following (simplified) code:

from typing import Iterable, List, Optional, overload, Literal, Union, Tuple, Any
import sqlite3

@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[False]
) -> List[sqlite3.Row]:
    ...


@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Run a query against the given db.

    If params is not None, securely construct a query from the given
    query string and params.
    """
    with sqlite3.connect("/dummy.sqlite") as con:
        if not as_tuple:
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    return rows



a = query_db("SELECT test_column FROM test_table")
a[0]["test_column"]

I don't know how to get this to typecheck.

Without the overloads, mypy complains that I might be indexing into a tuple with a str index.

Since as_tuple defaults to False, I expected mypy to infer the first overload when I don't pass the second and third arguments (the actual implementation has default values for both).

Instead, mypy complains that none of the overloads match, because it thinks I have to provide the last two arguments as well.

If I copy the default values into each overload, mypy complains that I can't assign False to as_tuple: Literal[True].

Is there a way to get this to typecheck the way it works at runtime? I really don't want to change the actual signature, since the function is used widely throughout our tests.
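For reference, this is the runtime behavior the overloads have to describe. The sketch below is a trimmed copy of query_db that, purely as an assumption for a standalone demo, connects to an in-memory database instead of /dummy.sqlite and closes the connection explicitly (sqlite3's connection context manager only commits or rolls back; it does not close the connection). The example queries select literals, so no table is needed.

```python
import sqlite3
from typing import Any, Iterable, List, Optional, Tuple, Union


def query_db(
    query: str, params: Optional[Iterable[Any]] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Trimmed copy of query_db; ":memory:" replaces /dummy.sqlite (demo assumption)."""
    con = sqlite3.connect(":memory:")
    try:
        if not as_tuple:
            # sqlite3.Row supports lookup by column name, e.g. row["test_column"].
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    finally:
        # Unlike "with sqlite3.connect(...)", this actually closes the connection.
        con.close()
    return rows


# Default call: sqlite3.Row objects, indexable by column name.
rows = query_db("SELECT 1 AS test_column")
print(type(rows[0]).__name__, rows[0]["test_column"])  # Row 1

# as_tuple=True: plain tuples, indexable only by position.
tuples = query_db("SELECT ? AS test_column", (42,), as_tuple=True)
print(tuples)  # [(42,)]
```

The return type depends only on the value of as_tuple, which is exactly what Literal-typed overloads can express.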


Answer

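An overload signature can declare that a parameter has a default by writing "..." (Ellipsis) as the default value; the real default lives only in the implementation. If you let parameters in some of your overloads take defaults this way, you don't need as many overloads. You probably also want a keyword-only overload for passing as_tuple=True without params, and an extra overload for when you pass a plain bool to as_tuple: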

from typing import Iterable, List, Optional, overload, Literal, Union, Tuple, Any
import sqlite3


# as_tuple omitted or False: rows come back as sqlite3.Row
@overload
def query_db(
    query: str, params: Optional[Iterable] = ..., as_tuple: Literal[False] = ...
) -> List[sqlite3.Row]:
    ...


# params given and as_tuple=True: rows come back as tuples
@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


# as_tuple=True passed by keyword without params
@overload
def query_db(
    query: str, *, as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


# as_tuple is a plain bool whose value mypy cannot know statically
@overload
def query_db(
    query: str, params: Optional[Iterable] = ..., as_tuple: bool = ...
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    ...


def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Run a query against the given db.

    If params is not None, securely construct a query from the given
    query string and params.
    """
    with sqlite3.connect("/dummy.sqlite") as con:
        if not as_tuple:
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    return rows

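# Declared (not assigned) only so mypy can check the calls below;
# run this part through mypy rather than executing it.
query: str
params: Optional[Iterable]
as_tuple: bool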

reveal_type(query_db(query, params, as_tuple=True))
reveal_type(query_db(query, as_tuple=True))
reveal_type(query_db(query, params))
reveal_type(query_db(query))
reveal_type(query_db(query, params, as_tuple=False))
reveal_type(query_db(query, as_tuple=False))
reveal_type(query_db(query, params, as_tuple=as_tuple))
reveal_type(query_db(query, as_tuple=as_tuple))

Running mypy on this gives:

main.py:51: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:52: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:53: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:54: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:55: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:56: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:57: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row], builtins.list[builtins.tuple[Any]]]'
main.py:58: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row], builtins.list[builtins.tuple[Any]]]'
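The same pattern, reduced to a hypothetical fetch function with no database, makes it easier to see which overload each call lands on. This is a minimal sketch rather than part of the original answer: fetch and its data are made up, the overloads use ... as a placeholder default, and the real default (False) lives only in the implementation.

```python
from typing import Dict, List, Literal, Tuple, Union, overload


# Overload bodies are never executed, so "..." stands in for the default;
# only the implementation below carries the real default value.
@overload
def fetch(as_tuple: Literal[False] = ...) -> List[Dict[str, int]]: ...
@overload
def fetch(as_tuple: Literal[True]) -> List[Tuple[int, ...]]: ...
@overload
def fetch(as_tuple: bool = ...) -> Union[List[Dict[str, int]], List[Tuple[int, ...]]]: ...


def fetch(as_tuple: bool = False) -> Union[List[Dict[str, int]], List[Tuple[int, ...]]]:
    """Hypothetical stand-in for query_db: same data, two result shapes."""
    data = {"a": 1, "b": 2}
    if as_tuple:
        # Positional access only, like rows fetched without a row_factory.
        return [tuple(data.values())]
    # Name-based access, like sqlite3.Row.
    return [data]


print(fetch())               # [{'a': 1, 'b': 2}]
print(fetch(as_tuple=True))  # [(1, 2)]
```

Under mypy, fetch() and fetch(False) should resolve to the first overload and fetch(as_tuple=True) to the second, while a call with a variable typed as bool gets the union return type.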