Hello, I was going through the pascal.ipynb notebook from course_v3 part 2.
There, the loss for RetinaNet is defined like so:
class RetinaNetFocalLoss(nn.Module):
    """Detection loss: focal loss on class scores plus `reg_loss` on boxes.

    The loss is computed image by image (`_one_loss`), summed over the batch
    and divided by the batch size (`forward`). Anchors are built lazily from
    the feature-map sizes returned by the model and cached until those sizes
    change.
    """

    def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales:Collection[float]=None,
                 ratios:Collection[float]=None, reg_loss:LossFunction=F.smooth_l1_loss):
        """Store hyper-parameters.

        gamma, alpha: focal-loss exponent and balancing factor (see `_focal_loss`).
        pad_idx: value used in class targets to mark padding entries.
        scales, ratios: forwarded to `create_anchors`; defaults are used when None.
        reg_loss: callable `(pred, target) -> loss` used for the box regression term.
        """
        super().__init__()
        self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss
        self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])
        self.ratios = ifnone(ratios, [1/2,1,2])

    def _change_anchors(self, sizes:Sizes) -> bool:
        """Return True when anchors must be (re)built for `sizes`.

        That is the case when no sizes have been cached yet, or when any cached
        size differs from the new one in its first two entries.
        """
        if not hasattr(self, 'sizes'): return True
        for sz1, sz2 in zip(self.sizes, sizes):
            if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True
        return False

    def _create_anchors(self, sizes:Sizes, device:torch.device):
        """Cache `sizes` and build the matching anchors on `device`."""
        self.sizes = sizes
        self.anchors = create_anchors(sizes, self.ratios, self.scales).to(device)

    def _unpad(self, bbox_tgt, clas_tgt):
        """Drop the leading padding entries from one image's targets.

        `clas_tgt - pad_idx` is zero exactly where the target is padding, so
        `torch.nonzero` lists the positions of the real objects and `i` is the
        smallest such position, i.e. the index of the first real object.
        Everything from `i` onwards is kept.

        Returns the kept boxes passed through `tlbr2cthw`, and the kept classes
        shifted by `-1 + pad_idx`.
        """
        # assumes clas_tgt is 1-D and padding sits at the front -- TODO confirm
        # NOTE(review): if every entry equals pad_idx, `nonzero` is empty and
        # `torch.min` raises; callers presumably never pass such a target.
        i = torch.min(torch.nonzero(clas_tgt-self.pad_idx))
        return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx

    def _focal_loss(self, clas_pred, clas_tgt):
        """Return the summed focal loss of `clas_pred` against `clas_tgt`.

        Implemented as binary cross-entropy with logits, weighted per element
        by `w**gamma * a`, where `w` is `1 - p` for target 1 and `p` for
        target 0, and `a` is `1 - alpha` for target 1 and `alpha` for target 0.
        """
        # presumably a one-hot style 0/1 encoding with clas_pred.size(1) columns -- confirm in encode_class
        encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
        # detach(): the weights are treated as constants, no gradient flows through them
        ps = torch.sigmoid(clas_pred.detach())
        weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps
        alphas = (1-encoded_tgt) * self.alpha + encoded_tgt * (1-self.alpha)
        weights.pow_(self.gamma).mul_(alphas)
        clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')
        return clas_loss

    def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
        """Return the loss for a single image.

        The result is the box regression loss over anchors with `matches >= 0`
        plus the focal loss divided by the number of such anchors (at least 1).
        """
        bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
        matches = match_anchors(self.anchors, bbox_tgt)
        # entries >= 0 are used below as indices into bbox_tgt
        bbox_mask = matches>=0
        if bbox_mask.sum() != 0:
            bbox_pred = bbox_pred[bbox_mask]
            bbox_tgt = bbox_tgt[matches[bbox_mask]]
            bb_loss = self.reg_loss(bbox_pred, bbox_to_activ(bbox_tgt, self.anchors[bbox_mask]))
        else: bb_loss = 0.
        # Shift matches by one so that a former -1 becomes index 0, which points
        # at the extra zero class prepended below; entries that are still
        # negative are left out of the classification loss.
        # presumably -1 means background and lower values mean "ignore" -- confirm in match_anchors
        matches.add_(1)
        clas_tgt = clas_tgt + 1
        clas_mask = matches>=0
        clas_pred = clas_pred[clas_mask]
        clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
        clas_tgt = clas_tgt[matches[clas_mask]]
        return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)

    def forward(self, output, bbox_tgts, clas_tgts):
        """Return the mean per-image loss over the batch.

        output: tuple `(clas_preds, bbox_preds, sizes)` produced by the model.
        bbox_tgts, clas_tgts: padded box and class targets, one row per image.
        """
        clas_preds, bbox_preds, sizes = output
        if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)
        return sum([self._one_loss(cp, bp, ct, bt)
                    for (cp, bp, ct, bt) in zip(clas_preds, bbox_preds, clas_tgts, bbox_tgts)])/clas_tgts.size(0)
I'm having trouble grasping what the `_unpad`
function does. What does `i`
signify here?