Skip connections in tabular data

While looking at the TensorFlow Playground I thought it was interesting how feeding input^2 and sin(input) as extra features can change what the network learns.
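For context, here is a minimal sketch of what I mean, assuming a pandas DataFrame df with a numeric column x (the names are made up):

import numpy as np
import pandas as pd

df = pd.DataFrame({'x': np.linspace(-3, 3, 100)})
# Hand-engineered features, like the TensorFlow Playground feature toggles:
df['x_sq'] = df['x'] ** 2      # input^2
df['x_sin'] = np.sin(df['x'])  # sin(input)
# These could then be fed to the model as extra continuous columns.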

I was wondering why TabularModel does not use skip connections, which seem to help in ResNet blocks (and seemed somewhat related :slight_smile: )
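For reference, a bare-bones additive skip connection of the kind used in ResNet blocks could look like this (just a sketch to show the idea, not fastai code):

import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.lin = nn.Linear(n_features, n_features)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        # Skip connection: add the input back onto the layer's output.
        return self.act(self.lin(x) + x)

My version below uses a concatenating (dense) skip instead of an additive one, since the layer widths differ.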

Here is my half-baked implementation, which creates a skip connection on the first layer.
I tried running it against a few datasets and the results did not seem to improve :frowning:
P.S.: After reading the tabular docs I tried testing my model on URLs.ADULT_SAMPLE, and noticed that a degenerate version of the same model (learn = tabular_learner(data, layers=[1], emb_szs={'native-country': 10}, metrics=accuracy)) yields the same accuracy, so this might not be a good benchmark.
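For reference, my setup roughly follows the example from the fastai tabular docs (the column names and validation split are the ones from that example):

from fastai.tabular import *

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dep_var = 'salary'
cat_names = ['workclass', 'education', 'marital-status', 'occupation',
             'relationship', 'race', 'sex', 'native-country']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [FillMissing, Categorify, Normalize]

data = (TabularList.from_df(df, path=path, cat_names=cat_names,
                            cont_names=cont_names, procs=procs)
        .split_by_idx(list(range(800, 1000)))
        .label_from_df(cols=dep_var)
        .databunch())

# The degenerate baseline that matched my model's accuracy:
learn = tabular_learner(data, layers=[1], emb_szs={'native-country': 10}, metrics=accuracy)
learn.fit_one_cycle(1, 1e-2)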

from fastai.torch_core import *
from fastai.layers import *

class SkipModel(nn.Module):
    "Basic model for tabular data."
    def __init__(self, emb_szs:ListSizes, n_cont:int, out_sz:int, layers:Collection[int], ps:Collection[float]=None,
                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):
        super().__init__()
        ps = ifnone(ps, [0]*len(layers))
        ps = listify(ps, layers)
        self.embeds = nn.ModuleList([embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(emb_drop)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range
        sizes = self.get_sizes(layers, out_sz)
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        layers = []
        n_concated = 0
        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+ps,actns)):
            # Original TabularModel body, for comparison:
            # layers += bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act)
            if i > 0:
                # Later layers: widen the input to absorb the concatenated skip
                # (n_concated is non-zero only for the layer right after the skip).
                layers += bn_drop_lin(n_in+n_concated, n_out, bn=use_bn, p=dp, actn=act)
                n_concated = 0
            else:
                # First layer: MergeLayer(dense=True) concatenates the layer's input
                # onto its output, giving a DenseNet-style skip connection.
                layers.append(SequentialEx(*bn_drop_lin(n_in, n_out, bn=False, p=dp, actn=act),
                                           MergeLayer(dense=True)))
                n_concated = n_in
        # By the output layer the skip has already been absorbed (this assumes at
        # least one hidden layer), so the final width is just sizes[-1].
        if bn_final: layers.append(nn.BatchNorm1d(sizes[-1]))
        self.layers = nn.Sequential(*layers)

    def get_sizes(self, layers, out_sz):
        return [self.n_emb + self.n_cont] + layers + [out_sz]

    def forward(self, x_cat:Tensor, x_cont:Tensor) -> Tensor:
        if self.n_emb != 0:
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        x = self.layers(x)
        if self.y_range is not None:
            x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]
        return x
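To actually train it, I mirror what tabular_learner does internally in fastai v1 (the layer sizes here are just an example):

from fastai.tabular import *  # for Learner, accuracy

# Assumes data is a TabularDataBunch, e.g. the ADULT_SAMPLE one from above.
emb_szs = data.get_emb_szs({'native-country': 10})
model = SkipModel(emb_szs, n_cont=len(data.cont_names), out_sz=data.c, layers=[200, 100])
learn = Learner(data, model, metrics=accuracy)
learn.fit_one_cycle(1, 1e-2)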

Wow, I had the exact same thought when I was using the tabular model: “I wonder if this would perform better with skip connections?” Thanks for dropping this here!