Truncating and moving models from one learner to another learner?

(xnet) #1

Well, this is confusing.

I’ve been training a U-Net. Now I want to cut off the decoder part and use the encoder as the backbone for a classifier.

The following seems to pull out the model, and I can then pull out its encoder part:

learn.model
newmodel = children(learn.model[:4])

I assumed that newmodel would be the encoder part, with the weights trained up to that point in my code.

But when I create a new cnn_learner with it, I get TypeError: unhashable type: 'list'

new_learn = cnn_learner(new_data, newmodel, path=savedir)

Can anyone help? Also, when I call models.resnet50, it doesn’t print out the layers. How do I convert a learner.model so that another learner can use it?
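
A minimal plain-PyTorch sketch of the difference between a list of layers and a usable model. The toy full_model below is a hypothetical stand-in for the trained U-Net (fastai’s real U-Net class is more complex), and I’m assuming fastai’s children helper behaves like list(m.children()):

```python
import torch
from torch import nn

# Hypothetical stand-in for a trained encoder-decoder model.
full_model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),   # "encoder" layers
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 1, 1),                         # "decoder" layer
)

# list(m.children()) returns a plain Python list of sub-modules,
# which is not a model and cannot be called on a batch.
parts = list(full_model.children())[:4]
print(type(parts))   # <class 'list'>

# Wrapping the list in nn.Sequential rebuilds a callable module.
# The layers are shared, not copied, so the trained weights come along.
encoder = nn.Sequential(*parts)
out = encoder(torch.randn(2, 3, 16, 16))
print(out.shape)     # torch.Size([2, 8, 16, 16])
```

Because nn.Sequential(*parts) reuses the same layer objects, further training of either model changes both; use copy.deepcopy(encoder) if you want an independent copy.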

(xnet) #2

Some updates:

newmodel = nn.Sequential(*list(children(learn.model[:4])))

Not sure if it’s a step in the right direction, but calling cnn_learner still fails. The problem is that cnn_learner calls create_cnn_model, which calls create_body, which in turn calls:

model = arch(pretrained)

It fails at this step. When I called models.resnet18(True), it returned a model that printed its list of layers, just like my newmodel. Calling newmodel(True), of course, raises an error.

Does anyone know how to obtain a models.resnet18 type object from a learner object that has already been trained? Thanks!
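
A sketch of why arch(pretrained) fails here: that line expects a function that builds and returns a model, not a model instance. The hypothetical wrapper encoder_arch below has the same call shape as models.resnet18 but simply returns an already-trained module. This only illustrates the call signature; I haven’t checked whether the rest of cnn_learner (for example, how it decides where to cut the body) handles such a function correctly.

```python
from torch import nn

# Hypothetical trained encoder (stand-in for the sliced U-Net encoder).
trained_encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

# models.resnet18 is a function: arch(pretrained) builds a module.
# This hypothetical wrapper has the same signature but returns
# the existing, already-trained module instead of building a new one.
def encoder_arch(pretrained: bool = False) -> nn.Module:
    # `pretrained` is ignored: the weights are already trained.
    return trained_encoder

body = encoder_arch(True)        # mirrors `model = arch(pretrained)`
print(body is trained_encoder)   # True
```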

(Karl) #3

You’re trying to create a learner through the create_cnn function, which isn’t designed to take a model as input; it expects an architecture. models.resnet18 isn’t a PyTorch model. It’s a function from torchvision that returns a PyTorch model, and create_cnn calls that function to build the model.

If you look at the code for cnn_learner, you’ll see that the learner is actually created via learn = Learner(data, model, **kwargs), where model is what you get after calling the torchvision function and adding a custom head.

You just need to add a new linear head to your encoder, then pass it directly to the Learner class.
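
A minimal sketch of that suggestion in plain PyTorch. The toy encoder, its 8 output channels, and num_classes = 5 are assumptions for illustration; with a real U-Net encoder, the in_features of nn.Linear must match the number of channels the encoder actually outputs.

```python
import torch
from torch import nn

# Hypothetical trained encoder that outputs 8-channel feature maps.
encoder = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
)

num_classes = 5  # assumption: set this to your dataset's class count

# New head: pool each feature map to 1x1, flatten, then classify.
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, num_classes),
)

classifier = nn.Sequential(encoder, head)
logits = classifier(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 5])

# With fastai v1 you would then pass the module directly, e.g.
# learn = Learner(new_data, classifier)   # not run here
```

AdaptiveAvgPool2d(1) makes the head independent of the input resolution, since the feature map is always pooled to 1x1 before the linear layer.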

(williamholding) #4

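hash() is a built-in Python function that returns an integer hash value for an object. The value is not guaranteed to be unique, but an object’s hash must not change while the object is in use, which is why hashing is meant for immutable values (user-defined objects are hashable by default, based on their identity). Hashes are used mainly for dictionary keys and set members.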

TypeError: unhashable type: 'list' usually means that you are using a list somewhere Python needs to hash it. For example, you can’t use a list as a dictionary key, because lists are mutable and therefore can’t be hashed; trying to do so raises this error. The standard way to solve this is to convert the list to a tuple.
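
A small demonstration of the error and the tuple fix. In this thread, a likely reason the error shows up (my assumption, not verified against the exact fastai version) is that cnn_learner looks up its architecture argument in a dictionary of per-architecture settings, so passing a list such as newmodel fails at the lookup. The meta table below is a hypothetical stand-in for such a lookup.

```python
def try_as_key(key):
    """Try to use `key` as a dict key; return None on success or the error message."""
    d = {}
    try:
        d[key] = "value"
    except TypeError as err:
        return str(err)
    return None

layers = [1, 2, 3]
print(try_as_key(layers))         # message mentions: unhashable type: 'list'
print(try_as_key(tuple(layers)))  # None: a tuple of hashable items is hashable

# dict.get hashes its argument too, so even a lookup with a list key fails.
meta = {"resnet18": "settings"}   # hypothetical table keyed by architecture
try:
    meta.get(layers, "default")
    lookup_failed = False
except TypeError:
    lookup_failed = True
print(lookup_failed)  # True
```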