```python
import torch
import torch.nn as nn
import torchvision
from torch.utils import model_zoo

def custom_resnet(pretrained=False):
    model = torchvision.models.resnet50(pretrained=pretrained)
    model = torch.nn.DataParallel(model).cuda()
    checkpoint = model_zoo.load_url(x)  # x: checkpoint URL (not shown here)
    model.load_state_dict(checkpoint["state_dict"])
    print(type(model))  # <class 'torch.nn.parallel.data_parallel.DataParallel'>
    all_layers = list(model.children())
    print(all_layers)  # [ResNet((conv1): ... (fc): Linear(in_features=2048, out_features=1000, bias=True))]
    return nn.Sequential(*all_layers[0], *all_layers[1:])
```
Hey guys! The above is my custom pretrained ResNet, which I am passing to `cnn_learner` as follows:
```python
learn = cnn_learner(src,
                    base_arch=custom_resnet,
                    pretrained=True,
                    metrics=error_rate)
```
I am getting the error `type object argument after * must be an iterable, not ResNet`. I guess I will have to convert the `ResNet` to `Sequential`?
Any help is appreciated. Thank you.
Edit: if I only do `return nn.Sequential(model)`, the error thrown is:

```
cut = next(i for i,o in reversed(ll) if has_pool_type(o))
StopIteration:
```