I’m using some custom fastai code to run a BYOL model, and a custom loss is needed. See: https://github.com/KeremTurgutlu/self_supervised/blob/682c1e5367c2779f3ed0f8a0dc15a5abd27d5a0f/nbs/20-byol.ipynb
I realized that learn.summary() does not work with custom loss functions, and I had similar issues before. This wasn’t the case with fastai v2 a while ago (before the official release).
class BYOL(Callback):
    """fastai Callback implementing BYOL self-supervised training.

    Maintains an EMA "target" network alongside the trained "online" model,
    feeds two augmented views of each batch through the learner, and exposes
    the target network's projections as the labels (`self.learn.yb`) so the
    loss function can compare online predictions against target projections.
    """
    def __init__(self, T=0.99, debug=True, size=224, **aug_kwargs):
        # T: initial EMA momentum for the target network (scheduled toward 1).
        # Two independent augmentation pipelines produce the two views.
        self.T, self.debug = T, debug
        self.aug1 = get_aug_pipe(size, **aug_kwargs)
        self.aug2 = get_aug_pipe(size, **aug_kwargs)

    def before_fit(self):
        "Create target model as a deep copy of the online model, plus the momentum schedule"
        self.target_model = copy.deepcopy(self.learn.model).to(self.dls.device)
        # Cosine schedule from the initial momentum T up to 1 over training.
        self.T_sched = SchedCos(self.T, 1)

    def before_batch(self):
        "Generate 2 views of the same image and calculate target projections for these views"
        if self.debug: print(f"self.x[0]: {self.x[0]}")
        # .clone() on the second view so the two pipelines don't mutate the same tensor.
        v1,v2 = self.aug1(self.x), self.aug2(self.x.clone())
        self.learn.xb = (v1,v2)
        if self.debug:
            print(f"v1[0]: {v1[0]}\nv2[0]: {v2[0]}")
            self.show_one()
        # Sanity check: the two views must actually differ.
        assert not torch.equal(*self.learn.xb)
        # Target projections are computed without gradients (target net is EMA-updated only).
        # NOTE(review): this runs the target model in whatever train/eval mode it was copied in;
        # with batch size 1 (e.g. learn.summary's fake batch) BatchNorm raises in train mode.
        with torch.no_grad():
            z1 = self.target_model.projector(self.target_model.encoder(v1))
            z2 = self.target_model.projector(self.target_model.encoder(v2))
        # Cross-pairing happens in the loss: q1 vs z2, q2 vs z1.
        self.learn.yb = (z1,z2)

    def after_step(self):
        "Update target model params by EMA of online params, and advance the momentum T"
        self.T = self.T_sched(self.pct_train)
        with torch.no_grad():
            for param_k, param_q in zip(self.target_model.parameters(), self.model.parameters()):
                # target <- T * target + (1 - T) * online
                param_k.data = param_k.data * self.T + param_q.data * (1. - self.T)

    def show_one(self):
        # Decode one random example from each view back to image space and display side by side.
        b1 = self.aug1.decode(to_detach(self.learn.xb[0]))
        b2 = self.aug2.decode(to_detach(self.learn.xb[1]))
        i = np.random.choice(len(b1))
        show_images([b1[i],b2[i]], nrows=1, ncols=2)
def _mse_loss(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
def symmetric_mse_loss(pred, *yb):
    """Symmetric BYOL loss.

    `pred` holds the online network's predictions for the two views (q1, q2);
    `yb` holds the target network's projections (z1, z2). Each prediction is
    compared against the *other* view's target projection, and the two terms
    are averaged over the batch.
    """
    q1, q2 = pred
    z1, z2 = yb
    loss_a = _mse_loss(q1, z2)
    loss_b = _mse_loss(q2, z1)
    return (loss_a + loss_b).mean()

# Alias used when constructing the Learner.
byol_loss = symmetric_mse_loss
After running this
learn = Learner(dls, model, byol_loss,
                cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])
Calling learn.summary() errors out… BUT if I do learn.fit(1) and call learn.summary() afterwards, no error!
Looking at the error output it starts from:
2079
2080 with torch.no_grad():
-> 2081 z1 = self.target_model.projector(self.target_model.encoder(v1))
2082 z2 = self.target_model.projector(self.target_model.encoder(v2))
2083 self.learn.yb = (z1,z2)
and ends in
python3.7/site-packages/torch/nn/functional.py in _verify_batch_size(size)
2035 size_prods *= size[i + 2]
2036 if size_prods == 1:
-> 2037 raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
2038
2039
> ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])
The thing I don’t understand is why does learn.summary() crash with that ValueError if I don’t call learn.fit() before learn.summary()?
My workaround ultimately was:
learn = Learner(dls, model, byol_loss, cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None), ShortEpochCallback(0.001)])
learn.fit(1)
learn = Learner(dls, model, byol_loss, cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])
learn.summary()
Because this below does not work!
learn = Learner(dls, model, byol_loss, cbs=[BYOL(T=0.99, size=28, debug=False, color=False, stats=None)])
learn.summary()