I am trying to use a trie data structure for a coding problem. Each node in a trie typically holds a collection of references to its children, so I thought about using defaultdict to create an empty trie node by default whenever a child does not exist during a lookup. However, I don't know how to make defaultdict refer to the class that encloses it.
I tried two approaches, and both failed. Here is what I tried.
from dataclasses import dataclass
from collections import defaultdict

@dataclass
class TrieNode():
    is_word = False
    children = defaultdict("TrieNode")
The code above produces:
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "<stdin>", line 4, in TrieNode
TypeError: first argument must be callable or None
@dataclass
class TrieNode():
    is_word = False
    children = defaultdict(TrieNode)
This produces:
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "<stdin>", line 4, in TrieNode
NameError: name 'TrieNode' is not defined
My question is: how do you use defaultdict to implement this elegantly?
Thank you very much in advance.
Answer
Your second approach, children = defaultdict(TrieNode), is closer to correct, since defaultdict needs a callable that constructs a TrieNode in order to populate itself with TrieNode instances. The first approach passes a string where a callable is expected. The problem with the second approach is that you are accessing the name TrieNode before the class has finished being created, which raises the NameError. To fix this, you can use children = defaultdict(lambda: TrieNode()). This way, the name TrieNode is only looked up when the lambda function is called.
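As a minimal sketch of why the lambda avoids the NameError (the variable names root and child are only illustrative), the lambda body is not evaluated while the class body runs, so TrieNode is resolved later, when a missing key is first looked up:

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class TrieNode:
    is_word = False
    # The lambda body is not run here, so the name TrieNode is only
    # resolved later, when a missing key triggers the factory.
    children = defaultdict(lambda: TrieNode())


root = TrieNode()
child = root.children["a"]   # the factory runs now; TrieNode exists by this point
print(type(child).__name__)  # TrieNode
```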
However, for a trie you want every node to have its own dictionary of children. With this approach, children is a class attribute, so every node shares the same dictionary object, and modifying the children of one node modifies them for all nodes. I would recommend using dataclasses.field with a default_factory to create a new dictionary for each TrieNode, like so:
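from dataclasses import dataclass, field
from collections import defaultdict

@dataclass
class TrieNode:
    # Only annotated attributes become dataclass fields
    is_word: bool = False
    # default_factory runs once per instance, so each node gets its own dict
    children: defaultdict = field(default_factory=lambda: defaultdict(TrieNode))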
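To show how this might be used, here is a rough sketch that repeats the class with annotated fields and adds two helpers. The functions insert and contains are hypothetical names made up for this example and are not part of the question or of any library.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    is_word: bool = False
    # default_factory runs once per instance, so each node gets its own dict.
    children: defaultdict = field(default_factory=lambda: defaultdict(TrieNode))


def insert(root: TrieNode, word: str) -> None:
    """Hypothetical helper: add word to the trie rooted at root."""
    node = root
    for ch in word:
        node = node.children[ch]  # missing children are created on demand
    node.is_word = True


def contains(root: TrieNode, word: str) -> bool:
    """Hypothetical helper: report whether word was inserted as a whole word."""
    node = root
    for ch in word:
        if ch not in node.children:  # a membership test does not create nodes
            return False
        node = node.children[ch]
    return node.is_word


root = TrieNode()
insert(root, "cat")
insert(root, "car")
print(contains(root, "cat"))  # True
print(contains(root, "ca"))   # False
print(root.children is TrieNode().children)  # False
```

One thing to keep in mind: indexing a defaultdict with a missing key inserts that key, so a lookup written as node.children[ch] would grow the trie as a side effect. That is why the sketch checks membership with "in" before descending.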