I’ve simplified it a bit, and it seems to be doing better (I’ll update with results). Here’s the conv_twist layer, which replaces each 3x3 convolution. I don’t think I can explain it more briefly than the code does:
class conv_twist(nn.Module):  # replacing 3x3 Conv2d
    """Drop-in replacement for a 3x3 ``nn.Conv2d`` that adds a position-dependent "twist".

    The output is::

        conv(x) + (X - center_x) * convx(x) + (Y - center_y) * convy(x)

    ``X`` and ``Y`` are per-pixel column/row coordinate ramps of the output
    feature map. ``center_x`` and ``center_y`` are learnable, one value per
    output channel. ``convx`` and ``convy`` are re-projected onto their odd
    (point-antisymmetric) part on every forward pass, so ``w[p] == -w[-p]``.
    Such kernels act like first-order (derivative-like) filters.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        stride: stride shared by all three convolutions (padding is 1).
    """

    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
        self.convx = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
        self.convy = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
        with torch.no_grad():
            # Start convx as an odd kernel, and convy as convx rotated by 90 degrees.
            self.convx.weight.copy_(self._odd_part(self.convx.weight))
            self.convy.weight.copy_(self.convx.weight.transpose(2, 3).flip(2))
        # Learnable per-output-channel centre of the twist field.
        self.center_x = nn.Parameter(torch.empty(nf).uniform_(-0.7, 0.7))
        self.center_y = nn.Parameter(torch.empty(nf).uniform_(-0.7, 0.7))

    @staticmethod
    def _odd_part(w):
        """Return the point-antisymmetric part ``(w[p] - w[-p]) / 2`` of a conv kernel."""
        return (w - w.flip(2).flip(3)) / 2

    def forward(self, x):
        # Optimizer updates do not preserve the odd-kernel constraint, so re-project
        # here. Writing through ``.data`` in place keeps the parameter's storage
        # (unlike reassigning ``.data``). It also does not bump the autograd version
        # counter, so a graph from an earlier forward can still be backpropagated.
        # The projection is idempotent.
        for weight in (self.convx.weight, self.convy.weight):
            weight.data.copy_(self._odd_part(weight.data))
        # (Earlier variants also re-tied convy to a 90-degree rotation of convx here
        # and multiplied the twist term by a learned radial mask; both are disabled.)

        x1 = self.conv(x)
        _, _, h, w = x1.shape
        # Column / row ramps (index * 2 / size), built directly on x's device as
        # 1-D tensors that broadcast, instead of numpy grids copied host->device.
        # NOTE(review): these ramps span [0, 2), not [-1, 1), while the centres are
        # initialised in [-0.7, 0.7]. The learnable centres can absorb the offset,
        # but confirm this is intended.
        cols = torch.arange(w, device=x.device, dtype=x.dtype) * 2 / w
        rows = torch.arange(h, device=x.device, dtype=x.dtype) * 2 / h
        XX = cols.view(1, 1, w) - self.center_x.view(-1, 1, 1)  # (nf, 1, w)
        YY = rows.view(1, h, 1) - self.center_y.view(-1, 1, 1)  # (nf, h, 1)
        return x1 + (XX * self.convx(x) + YY * self.convy(x))
Update: Imagewoof2 results

Size (px) | Epochs | model | mixup | Accuracy | # Runs |
---|---|---|---|---|---|
128 | 5 | (Leaderboard) | | 73.37% | 5, mean |
128 | 5 | RMS | 0 | 68.54% | 5, mean |
128 | 5 | RMS + twist | 0 | 70.95% | 5, mean |
128 | 20 | (Leaderboard) | | 85.52% | 5, mean |
128 | 20 | RMS | 0 | 84.62% | 5, mean |
128 | 20 | RMS + twist | 0 | 85.24% | 5, mean |
128 | 80 | (Leaderboard) | | 87.20% | 1 |
128 | 80 | RMS + twist | 0.2 | 87.81% | 1 |
128 | 80 | RMS + twist | 0.5 | 88.52% | 1 |
128 | 200 | (Leaderboard) | | 87.20% | 1 |
128 | 200 | RMS + twist | 0.2 | 88.70% | 1 |
256 | 200 | (Leaderboard) | | 90.38% | 1 |
256 | 200 | RMS + twist | 0.2 | 91.52% | 1 |
Imagenette2 results

Size (px) | Epochs | model | mixup | Accuracy | # Runs |
---|---|---|---|---|---|
256 | 200 | (Leaderboard) | | 95.11% | 1 |
256 | 200 | RMS + twist | 0.5 | 95.87% | 1 |
@a_yasyrev, if you could help test with your ResNet trick + MaxBlurPool, that would be very nice.