
Using densenet with fastai

I am trying to train a DenseNet model with the fast.ai library. Following the documentation, I got it working with resnet50, but with DenseNet fastai cannot find the model.

I also tried arch=models.dn121, as suggested in a fast.ai forum thread, but I get the same error.

Can anyone please help?

Here is the code:

learn = create_cnn(data, base_arch=models.densenet201, metrics=accuracy, model_dir="/tmp/model/")

This is the error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-23-cb9ab3a79572> in <module>()
----> 1 learn = create_cnn(data, base_arch=models.densenet201, metrics=accuracy, model_dir="/tmp/model/")

AttributeError: module 'fastai.vision.models' has no attribute 'densenet201'


Answer

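According to a post on the fast.ai forum, this is how to use DenseNet with fast.ai. create_cnn accepts any function that takes a pretrained flag and returns a model, so you can wrap torchvision's densenet121 yourself instead of relying on fastai.vision.models to export it: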

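from fastai.vision import *  # fastai v1: provides create_cnn and children
from torchvision.models import densenet121

# create_cnn calls the architecture function with a pretrained flag.
# children() lists a module's direct submodules; for DenseNet the first one is
# the convolutional feature extractor (`features`), so the classifier is dropped.
def dn121(pre): return children(densenet121(pretrained=pre))[0]

learn = create_cnn(data, dn121)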

User contributions licensed under: CC BY-SA