Hey everyone, I’ve been experimenting with ensembling CNN Learners and came up with the following.
class Ensemble:
    """Combine several trained models by averaging their predictions.

    Each entry of ``models`` is used through three members only:
    ``predict(item)`` (a 3-tuple whose last element is a probability tensor),
    ``get_preds(dl=..., with_input=..., with_loss=..., with_decoded=..., act=...)``
    (returning ``(inputs, preds, targs, decoded, loss)``) and ``dls[1]``.
    NOTE(review): this matches the fastai ``Learner`` API the post targets;
    confirm before using other model types.
    """

    def __init__(self, dl, models: dict, vocab: "list | None" = None):
        """Store the models, the default dataloader and the class labels.

        Args:
            dl: Dataloader that ``calc_metrics`` evaluates on. If None,
                ``get_preds`` falls back to the first model's ``dls[1]``.
            models: Mapping of display name -> model.
            vocab: Labels indexed by predicted class id. Defaults to
                ``[0, 1]``.
        """
        self.models = models
        # Build a fresh list per instance; a literal `[0, 1]` default would
        # be one list shared (and mutable) across every Ensemble.
        self.vocab = [0, 1] if vocab is None else vocab
        self.dl = dl
        print(f'vocab: {self.vocab}')
        for name in models:
            print(f'loaded: {name}')

    def calc_probas(self, item):
        """Return each model's probabilities for ``item``, stacked on dim 0.

        Row ``i`` is the third element of ``predict(item)`` for the ``i``-th
        model in ``self.models`` iteration order.
        """
        probas = [p for _, _, p in (model.predict(item) for model in self.models.values())]
        return torch.stack(probas, dim=0)

    def predict(self, item):
        """Predict a single item from the averaged model probabilities.

        Returns:
            ``(label, mean, std)``:
            - ``mean`` and ``std`` are the across-model mean and standard
              deviation of the probabilities.
            - ``label`` is ``self.vocab`` at ``mean``'s arg-max.
            With only one model, ``std`` is NaN because torch's default
            ``std`` is the unbiased estimator.
        """
        probas = self.calc_probas(item)
        mean, std = probas.mean(dim=0), probas.std(dim=0)
        return self.vocab[mean.argmax().item()], mean, std

    def get_preds(self, dl=None, with_input=True, with_loss=True, with_decoded=True, act=None):
        """Average every model's predictions over a dataloader.

        Args:
            dl: Dataloader to predict on. Defaults to the first model's
                ``dls[1]``.
            with_input: Return the inputs as the first element.
            with_loss: Return the per-item loss, averaged across models, last.
            with_decoded: Return the arg-max class ids after the targets.
            act: Activation forwarded to every model's ``get_preds``.

        Returns:
            A tuple in the order ``(inputs, preds, targs, decoded, loss)``.
            - An element is omitted when its flag is False.
            - ``preds`` is the element-wise mean of the models' predictions.
            - ``inputs`` and ``targs`` come from the last model. This assumes
              every model yields ``dl``'s items in the same order.
              NOTE(review): confirm ``dl`` is not shuffled.

        Raises:
            ValueError: ``dl`` is None and the ensemble has no models.
        """
        if dl is None:
            if not self.models:
                raise ValueError('cannot pick a default dataloader: the ensemble has no models')
            dl = next(iter(self.models.values())).dls[1]
        predictions, losses = [], []
        for name, model in self.models.items():
            print(f'Getting predictions from {name}')
            # Always ask the members for everything. The caller's flags only
            # control what this method returns.
            inputs, preds, targs, _, loss = model.get_preds(
                dl=dl, with_input=True, with_loss=True, with_decoded=True, act=act)
            predictions.append(preds)
            losses.append(loss)
        preds = torch.stack(predictions).mean(0)
        result = [inputs] if with_input else []
        result += [preds, targs]
        if with_decoded:
            # assumes preds is (n_items, n_classes), i.e. class ids along dim 1
            result.append(preds.argmax(1))
        if with_loss:
            # Per-item mean of the members' losses, not the ensemble's own loss.
            result.append(torch.stack(losses, dim=1).mean(1))
        return tuple(result)

    def calc_metrics(self, metrics: dict):
        """Evaluate each metric on the ensemble's predictions over ``self.dl``.

        Args:
            metrics: Mapping of name -> callable invoked as
                ``metric(decoded, targs)``.

        Returns:
            dict mapping each metric name to the value its callable returned.
        """
        _, targs, decoded = self.get_preds(self.dl, with_input=False, with_loss=False, with_decoded=True)
        return {name: metric(decoded, targs) for name, metric in metrics.items()}
Here’s a link to a Colab notebook with an example showing how I ensemble three CNN Learners. Let me know your thoughts or suggestions — I hope this is useful!