learn.summary() issues with custom losses

I’m running a BYOL model with some custom fastai code, which requires a custom loss function. See: https://github.com/KeremTurgutlu/self_supervised/blob/682c1e5367c2779f3ed0f8a0dc15a5abd27d5a0f/nbs/20-byol.ipynb

I’ve realized that learn.summary() does not work with custom loss functions, and I’ve run into similar issues before. This wasn’t the case with fastai v2 a while ago (before the official release).

import copy
import numpy as np
import torch
import torch.nn.functional as F
from fastai.vision.all import *  # Callback, SchedCos, Learner, to_detach, show_images, ...

class BYOL(Callback):
    def __init__(self, T=0.99, debug=True, size=224, **aug_kwargs):
        self.T, self.debug = T, debug
        # get_aug_pipe comes from the self_supervised repo linked above
        self.aug1 = get_aug_pipe(size, **aug_kwargs)
        self.aug2 = get_aug_pipe(size, **aug_kwargs)

    def before_fit(self):
        "Create target model"
        self.target_model = copy.deepcopy(self.learn.model).to(self.dls.device)
        self.T_sched = SchedCos(self.T, 1)

    def before_batch(self):
        "Generate 2 views of the same image and calculate target projections for these views"
        if self.debug: print(f"self.x[0]: {self.x[0]}")

        v1,v2 = self.aug1(self.x), self.aug2(self.x.clone())
        self.learn.xb = (v1,v2)

        if self.debug:
            print(f"v1[0]: {v1[0]}\nv2[0]: {v2[0]}")
            self.show_one()
            assert not torch.equal(*self.learn.xb)

        # Target projections act as the "labels" the loss compares against
        with torch.no_grad():
            z1 = self.target_model.projector(self.target_model.encoder(v1))
            z2 = self.target_model.projector(self.target_model.encoder(v2))
            self.learn.yb = (z1,z2)

    def after_step(self):
        "Update target model and T"
        self.T = self.T_sched(self.pct_train)
        # EMA/momentum update: the target network slowly tracks the online network
        with torch.no_grad():
            for param_k, param_q in zip(self.target_model.parameters(), self.model.parameters()):
                param_k.data = param_k.data * self.T + param_q.data * (1. - self.T)

    def show_one(self):
        b1 = self.aug1.decode(to_detach(self.learn.xb[0]))
        b2 = self.aug2.decode(to_detach(self.learn.xb[1]))
        i = np.random.choice(len(b1))
        show_images([b1[i],b2[i]], nrows=1, ncols=2)
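
(Aside: the after_step update is a plain exponential moving average of the online weights into the target weights. In isolation, and with names of my choosing, it amounts to this sketch:)

@torch.no_grad()
def ema_update(target_model, online_model, momentum):
    # param_k <- momentum * param_k + (1 - momentum) * param_q
    for param_k, param_q in zip(target_model.parameters(), online_model.parameters()):
        param_k.mul_(momentum).add_(param_q, alpha=1.0 - momentum)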

def _mse_loss(x, y):
    # On unit-normalized vectors this equals 2 - 2 * cosine_similarity(x, y)
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

def symmetric_mse_loss(pred, *yb):
    # pred is the pair of online predictions, yb the pair of target projections
    (q1,q2),z1,z2 = pred,*yb
    return (_mse_loss(q1,z2) + _mse_loss(q2,z1)).mean()

byol_loss = symmetric_mse_loss
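
As a quick sanity check of the loss signature (pred is the tuple of predictions, *yb the targets that before_batch stores in learn.yb), here it is on random tensors with an arbitrary projection size of 256:

q1, q2 = torch.randn(8, 256), torch.randn(8, 256)
z1, z2 = torch.randn(8, 256), torch.randn(8, 256)
loss = symmetric_mse_loss((q1, q2), z1, z2)
print(loss)  # scalar tensor; each term is 2 - 2 * cosine similarity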

After building the learner like this:

learn = Learner(dls, model, byol_loss,
                cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])

Calling learn.summary() errors out… BUT if I run learn.fit(1) first and then call learn.summary(), there is no error!

Looking at the traceback, it starts from:

   2079
   2080         with torch.no_grad():
-> 2081             z1 = self.target_model.projector(self.target_model.encoder(v1))
   2082             z2 = self.target_model.projector(self.target_model.encoder(v2))
   2083             self.learn.yb = (z1,z2)

and ends with:

python3.7/site-packages/torch/nn/functional.py in _verify_batch_size(size)
   2035         size_prods *= size[i + 2]
   2036     if size_prods == 1:
-> 2037         raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
   2038
   2039

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])

The thing I don’t understand: why does learn.summary() crash with that ValueError if I don’t call learn.fit() first?
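
For what it’s worth, the ValueError itself is the generic PyTorch batch-norm complaint: a BatchNorm layer in training mode cannot compute per-channel statistics from a batch of one, and as far as I can tell learn.summary() traces the model with a single-item batch (hence the torch.Size([1, 512]) hitting a norm layer in the projector). A minimal repro outside fastai, assuming a BatchNorm1d(512) like the one the projector presumably contains:

bn = torch.nn.BatchNorm1d(512)
bn.train()                   # training mode -> needs batch statistics
try:
    bn(torch.randn(1, 512))  # batch of one: per-channel variance is undefined
except ValueError as e:
    print(e)                 # Expected more than 1 value per channel when training, ...

bn.eval()                    # eval mode uses running stats instead
bn(torch.randn(1, 512))      # works fine

So the puzzle is really why the target model ends up in training mode for summary-before-fit but not for summary-after-fit.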

My workaround ultimately was:

learn = Learner(dls, model, byol_loss,
                cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None),
                     ShortEpochCallback(0.001)])
learn.fit(1)
learn = Learner(dls, model, byol_loss,
                cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])
learn.summary()

This is because the straightforward version below does not work:

learn = Learner(dls, model, byol_loss,
                cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])
learn.summary()
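
If the batch-norm reading above is right, my guess is that before_fit deep-copies the model while it is still in training mode (whereas after fit(1) it is left in eval mode by the final validation pass, which would explain why summary works afterwards). In that case a lighter workaround might be the sketch below, though I haven’t verified it:

learn = Learner(dls, model, byol_loss,
                cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])
learn.model.eval()   # so the deepcopy in before_fit inherits eval-mode BatchNorm (untested)
learn.summary()
learn.model.train()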