So I have the following (simplified) code, which I don't know how to get to typecheck:

```python
from typing import Iterable, List, Optional, overload, Literal, Union, Tuple, Any
import sqlite3


@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[False]
) -> List[sqlite3.Row]:
    ...


@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Run a query against the given db.

    If params is not None, securely construct a query from the given
    query string and params.
    """
    with sqlite3.connect("/dummy.sqlite") as con:
        if not as_tuple:
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    return rows


a = query_db("SELECT test_column FROM test_table")  # mypy: no overload variant matches
a[0]["test_column"]
```
If I don't add the overloads, mypy complains that I might be indexing into a tuple with a `str` index.
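The complaint reflects a real runtime difference: a `sqlite3.Row` can be indexed by column name, while a plain tuple (what rows are without a row factory) only accepts integer indices. A minimal sketch of that difference, assuming an in-memory database and a hypothetical `test_table` in place of the question's `/dummy.sqlite`:

```python
import sqlite3

# In-memory database standing in for "/dummy.sqlite" (illustrative assumption).
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE test_table (test_column TEXT)")
con.execute("INSERT INTO test_table VALUES ('hello')")

# Default row factory: every row comes back as a plain tuple.
plain = con.execute("SELECT test_column FROM test_table").fetchall()
tuple_error = None
try:
    plain[0]["test_column"]  # tuples accept only integer (or slice) indices
except TypeError as exc:
    tuple_error = exc
print("tuple row:", plain[0], "->", tuple_error)

# sqlite3.Row factory: rows can be indexed by column name or by position.
con.row_factory = sqlite3.Row
named = con.execute("SELECT test_column FROM test_table").fetchall()
print("Row:", named[0]["test_column"], named[0][0])
con.close()
```

So mypy is right to reject `a[0]["test_column"]` unless it can tell which of the two result types the call returns.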
The `as_tuple` parameter defaults to `False`, so mypy should be able to infer that I'm using the first overload when I don't pass the second and third arguments (the actual implementation gives them default values).
However, mypy instead complains that none of the provided overloads match, because it thinks I need to provide the last two arguments as well.
When I just copy-paste the default arguments into each of the overloads, mypy complains that I can't assign `False` to `as_tuple: Literal[True]`.
Is there an option to get this to typecheck the way it works at runtime? I really don’t want to modify the actual signature as the function is used widely throughout our tests.
Answer
If you give some of the overloads' parameters defaults, you don't need as many overloads. You probably also want a keyword-only overload for `as_tuple=True` without `params`, plus a catch-all overload for when you pass a plain `bool` (rather than a literal) to `as_tuple`:
```python
from typing import Iterable, List, Optional, overload, Literal, Union, Tuple, Any
import sqlite3


# as_tuple omitted or False: rows are sqlite3.Row.
@overload
def query_db(
    query: str, params: Optional[Iterable] = ..., as_tuple: Literal[False] = ...
) -> List[sqlite3.Row]:
    ...


# as_tuple=True with params passed: rows are plain tuples.
@overload
def query_db(
    query: str, params: Optional[Iterable], as_tuple: Literal[True]
) -> List[Tuple[Any, ...]]:
    ...


# as_tuple=True passed by keyword, params omitted.
@overload
def query_db(query: str, *, as_tuple: Literal[True]) -> List[Tuple[Any, ...]]:
    ...


# as_tuple is a plain bool whose value mypy cannot know statically.
@overload
def query_db(
    query: str, params: Optional[Iterable] = ..., as_tuple: bool = ...
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    ...


def query_db(
    query: str, params: Optional[Iterable] = None, as_tuple: bool = False
) -> Union[List[sqlite3.Row], List[Tuple[Any, ...]]]:
    """Run a query against the given db.

    If params is not None, securely construct a query from the given
    query string and params.
    """
    with sqlite3.connect("/dummy.sqlite") as con:
        if not as_tuple:
            con.row_factory = sqlite3.Row
        if params is None:
            rows = con.execute(query).fetchall()
        else:
            rows = con.execute(query, params).fetchall()
    return rows


# Probes for mypy only: these names are declared but never assigned,
# so this part is meant to be type-checked, not run.
query: str
params: Optional[Iterable]
as_tuple: bool

reveal_type(query_db(query, params, as_tuple=True))
reveal_type(query_db(query, as_tuple=True))
reveal_type(query_db(query, params))
reveal_type(query_db(query))
reveal_type(query_db(query, params, as_tuple=False))
reveal_type(query_db(query, as_tuple=False))
reveal_type(query_db(query, params, as_tuple=as_tuple))
reveal_type(query_db(query, as_tuple=as_tuple))
```
Running mypy on this gives:
```
main.py:51: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:52: note: Revealed type is 'builtins.list[builtins.tuple[Any]]'
main.py:53: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:54: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:55: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:56: note: Revealed type is 'builtins.list[sqlite3.dbapi2.Row]'
main.py:57: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row], builtins.list[builtins.tuple[Any]]]'
main.py:58: note: Revealed type is 'Union[builtins.list[sqlite3.dbapi2.Row], builtins.list[builtins.tuple[Any]]]'
```
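The last two calls, where `as_tuple` is a plain `bool`, return the `Union`, so indexing by column name will still be rejected there until the caller narrows the row type. One possible approach, sketched with a hypothetical helper `first_value` (not part of the question or answer), is an `isinstance` check, which mypy uses to narrow `Union[sqlite3.Row, Tuple[Any, ...]]`:

```python
import sqlite3
from typing import Any, List, Tuple, Union

RowList = Union[List[sqlite3.Row], List[Tuple[Any, ...]]]


def first_value(rows: RowList, column: str, index: int) -> Any:
    """Hypothetical helper: read one column from the first row.

    Uses the column name for sqlite3.Row results and falls back to the
    positional index for plain tuples.
    """
    row = rows[0]
    if isinstance(row, sqlite3.Row):
        # Narrowed to sqlite3.Row, so indexing with a str type-checks.
        return row[column]
    # Narrowed to Tuple[Any, ...]: only integer indices are valid.
    return row[index]


# Demo data in an in-memory database (illustrative assumption).
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE test_table (test_column TEXT)")
con.execute("INSERT INTO test_table VALUES ('hello')")
as_tuples = con.execute("SELECT test_column FROM test_table").fetchall()
con.row_factory = sqlite3.Row
as_rows = con.execute("SELECT test_column FROM test_table").fetchall()
con.close()

print(first_value(as_tuples, "test_column", 0), first_value(as_rows, "test_column", 0))
```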