When declaring a model in PyTorch, declaring and initializing a member variable at class level mysteriously prevents it from being populated in the constructor. Is this expected? If so, why?
Testing code is below, with example models that have a `component` member variable. The initialization value of the component (e.g. `None`, a number, or a Tensor) does not change the behaviour.
```python
import torch

class Lin1(torch.nn.Module):
    def __init__(self):
        super(Lin1, self).__init__()
        self.component = torch.nn.Linear(100, 200)

class Lin2(torch.nn.Module):
    component = None

    def __init__(self):
        super(Lin2, self).__init__()
        self.component = torch.nn.Linear(100, 200)

# instantiate and check the member component
for cl in [Lin1, Lin2]:
    model = cl()
    print("\nModel:", cl)
    print("Component:", model.component)
    print("It's None?: ", model.component is None)
```
```
Model: <class '__main__.Lin1'>
Component: Linear(in_features=100, out_features=200, bias=True)
It's None?:  False

Model: <class '__main__.Lin2'>
Component: None
It's None?:  True
```
Answer
This happens because `nn.Module` overrides `__getattr__`, and it would only work as you expect if `component` were in neither `Lin2.__dict__` nor `Lin2().__dict__`. Since `component` is a class attribute, it is in `Lin2.__dict__`, so normal attribute lookup finds the `None` stored there and returns it.
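To see this lookup order without any PyTorch involved, here is a minimal plain-Python sketch (the class and `registry` names are invented for illustration): `__getattr__` is only a fallback, so a class attribute found by normal lookup shadows it.

```python
# Minimal sketch (plain Python, hypothetical names): __getattr__ is only
# called when normal attribute lookup fails, so a class attribute
# short-circuits it -- the same shadowing that bites Lin2 above.
class Fallback:
    registry = {"x": "from registry", "y": "from registry"}
    x = None  # class attribute: found by normal lookup

    def __getattr__(self, name):
        # Reached only when normal lookup fails (no instance/class attribute)
        try:
            return self.registry[name]
        except KeyError:
            raise AttributeError(name)

obj = Fallback()
print(obj.x)  # None          -- the class attribute shadows the registry
print(obj.y)  # from registry -- no attribute named y, so __getattr__ runs
```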
When you write `self.x = nn.Linear(...)`, or assign any other `nn.Module` (or even an `nn.Buffer` or `nn.Parameter`), `x` is actually registered in a dictionary called `_modules` (or `_buffers`, etc.), not in the instance's `__dict__`. Python only calls the custom `__getattr__()` of `nn.Module` when normal attribute lookup fails, so when you ask for `self.component` and `component` is already in the `__dict__` of the class or the instance, the entry in `_modules` is never consulted.
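You can confirm this with the `Lin2` class from the question (a quick sketch reusing the code above): the Linear layer is still registered in `_modules`; it is merely shadowed at lookup time.

```python
# Sketch reusing Lin2 from the question: the layer was registered,
# it is just shadowed by the class attribute during attribute lookup.
model = Lin2()
print(model.component)                # None -- class attribute wins
print(model._modules["component"])    # Linear(in_features=100, out_features=200, bias=True)
print("component" in model.__dict__)  # False -- stored in _modules, not __dict__
```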
You can check the source code of `__getattr__` from `nn.Module` here. A similar discussion took place here. There was also a discussion about changing from `__getattr__` to `__getattribute__` in PyTorch, but as of now this is a wontfix issue.
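For comparison, here is a minimal plain-Python sketch (invented names, not PyTorch's actual proposal) of why `__getattribute__` would behave differently: unlike `__getattr__`, it intercepts every attribute access, so a registry lookup could take precedence over the class attribute.

```python
# Hypothetical sketch: __getattribute__ intercepts *all* lookups,
# so the registry can win even when a class attribute exists.
class Intercepting:
    x = None
    registry = {"x": "from registry"}

    def __getattribute__(self, name):
        registry = object.__getattribute__(self, "registry")
        if name != "registry" and name in registry:
            return registry[name]
        return object.__getattribute__(self, name)

print(Intercepting().x)  # from registry -- the class attribute no longer shadows it
```

The trade-off is that `__getattribute__` runs on every attribute access rather than only on failed lookups, which makes it a much more invasive change.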