I am trying to train a DenseNet model using the fast.ai library. Following the documentation, I got it working with ResNet-50, but with DenseNet the library cannot find the model.
I tried using `arch=models.dn121`, as suggested on this forum, but I get the same error.
Can anyone please help?
Here is the code:
```python
learn = create_cnn(data, base_arch=models.densenet201, metrics=accuracy, model_dir="/tmp/model/")
```
This is the error:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-23-cb9ab3a79572> in <module>()
----> 1 learn = create_cnn(data, base_arch=models.densenet201, metrics=accuracy, model_dir="/tmp/model/")

AttributeError: module 'fastai.vision.models' has no attribute 'densenet201'
```
Answer
According to this post on the fast.ai forum, this is how to use DenseNet with fast.ai:
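```python
from fastai.vision import *  # fastai v1: provides create_cnn, children, accuracy
from torchvision.models import densenet121

# Return only DenseNet-121's convolutional feature extractor (its first child);
# create_cnn then puts its own classification head on top.
def dn121(pre): return children(densenet121(pre))[0]

learn = create_cnn(data, dn121)
```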