
Python `collections.defaultdict` whose default factory is the enclosing class

I am trying to use a trie data structure for a coding problem. Each node in a trie typically keeps references to its children, so I thought about using defaultdict to create an empty trie node automatically whenever a child does not exist during a lookup. However, I don't know how to make the defaultdict refer to the class that encloses it.

I tried two approaches, both of which failed. Here is what I tried.

from dataclasses import dataclass
from collections import defaultdict

@dataclass   
class TrieNode():
    is_word = False
    children = defaultdict("TrieNode")

The code above produces:

Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "<stdin>", line 4, in TrieNode
TypeError: first argument must be callable or None

My second attempt:

@dataclass
class TrieNode():
    is_word = False
    children = defaultdict(TrieNode)

The above produces:

Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "<stdin>", line 4, in TrieNode
NameError: name 'TrieNode' is not defined

My question is: how can I use defaultdict to implement this elegantly? Thank you very much in advance.


Answer

Your second approach, children = defaultdict(TrieNode), is closer to correct: defaultdict needs a callable (here the TrieNode constructor) that it can call to create missing values, whereas the first approach passes a string where a callable is expected. The NameError happens because the class body runs before the name TrieNode is bound, so TrieNode cannot be referenced directly inside its own class body. To fix this you can use children = defaultdict(lambda: TrieNode()). This way, the name TrieNode is only looked up when the lambda is called, which happens the first time a missing key is accessed, long after the class has been created.
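A minimal sketch of the lambda fix, to check that the class now builds and that a missing key creates a child node on demand. The @dataclass decorator is left off here on the assumption that, with no annotated fields, it does not change this behavior. The last lines also check whether two separate nodes end up sharing one children dictionary.

```python
from collections import defaultdict


class TrieNode:
    is_word = False
    # The lambda body runs only when a missing key is accessed, i.e. after
    # the class statement has finished, so the name TrieNode resolves then.
    children = defaultdict(lambda: TrieNode())


a = TrieNode()
b = TrieNode()

# Accessing a missing key creates a new TrieNode on demand.
child = a.children["x"]
print(type(child).__name__)        # TrieNode

# children is a class attribute, so both nodes see the same dict object.
print(a.children is b.children)    # True
print("x" in b.children)           # True
```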

However, for a trie you want every node to have its own dictionary of children. With this approach, children is a class attribute, so all nodes share the same dictionary object, and adding a child to one node adds it to every node. I would recommend using dataclasses.field with a default_factory so that each TrieNode gets a new dictionary, like so:

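from dataclasses import dataclass, field
from collections import defaultdict

@dataclass
class TrieNode:
    # Annotated so that dataclass treats is_word as a real field.
    is_word: bool = False
    # default_factory runs once per instance, so each node gets its own dict;
    # the lambda defers looking up TrieNode until the class exists.
    children: 'defaultdict[str, TrieNode]' = field(default_factory=lambda: defaultdict(TrieNode))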
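To show the per-instance children dictionary in use, here is a small sketch of inserting and searching words. The helper functions insert and search are hypothetical names introduced for illustration, not part of the answer above, and the string annotation on children is an assumed type hint (dataclass does not evaluate it).

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    is_word: bool = False
    # Each instance gets a fresh defaultdict whose missing keys become new nodes.
    children: "defaultdict[str, TrieNode]" = field(
        default_factory=lambda: defaultdict(TrieNode)
    )


def insert(root: TrieNode, word: str) -> None:
    """Hypothetical helper: walk down, letting defaultdict create missing nodes."""
    node = root
    for ch in word:
        node = node.children[ch]
    node.is_word = True


def search(root: TrieNode, word: str) -> bool:
    """Hypothetical helper: look up a word without creating any nodes."""
    node = root
    for ch in word:
        # .get() does not call the default factory, so a failed lookup
        # leaves the trie unchanged.
        node = node.children.get(ch)
        if node is None:
            return False
    return node.is_word


root = TrieNode()
for w in ["tea", "ten", "to"]:
    insert(root, w)

print(search(root, "ten"))                    # True
print(search(root, "te"))                     # False (prefix only)
print(search(root, "tax"))                    # False
print(sorted(root.children["t"].children))    # ['e', 'o']
```

The design choice here is to rely on the default factory only while inserting, and to use .get() while searching, so that looking up a word that is not present does not grow the trie.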
User contributions licensed under: CC BY-SA