I don’t have an exact plan as of now. (Also, I’m home for the holidays and my family does not appreciate me spending time on my laptop while I’m here, so it’s hard to really get started on anything.)
That said, I do have at least one very specific idea in mind, and I think it’s pretty cool.
So, after doing Part 2 of the course, I wanted to test a modification to BatchNorm layers. I was looking for a function that can squash all of its inputs into the range -1 to 1 without any parameters. Naturally, I turned to activation functions such as `Tanh`. After some experimentation in the `07_batchnorm` notebook itself, what finally seemed to work reliably was applying `Tanh` followed by the usual scaling and shifting with `mults` and `adds` in BatchNorm. (Essentially, I replaced the normalization part of a `BatchNorm` layer with a `Tanh`.) The final results looked very close to the usual BatchNorm, at least for the network in that notebook: 97.35% vs 97.66% accuracy after 5 epochs each.
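The layer’s code isn’t shown in this post. Below is a minimal sketch of what the description amounts to, assuming 4-D `(N, C, H, W)` inputs and per-channel `mults`/`adds` initialised to ones and zeros like the notebook’s BatchNorm. The name `TanHNorm` is borrowed from the snippet further down, but the body is a guess at the idea, not the actual module.

```python
import torch
import torch.nn as nn


class TanHNorm(nn.Module):
    """Sketch (assumed, not the original): BatchNorm with the normalisation
    step replaced by tanh. No batch mean/variance, no running statistics."""

    def __init__(self, nf: int):
        super().__init__()
        # Per-channel scale and shift, shaped to broadcast over (N, C, H, W).
        self.mults = nn.Parameter(torch.ones(nf, 1, 1))
        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squash every element into (-1, 1), then scale and shift per channel.
        return torch.tanh(x) * self.mults + self.adds


layer = TanHNorm(8)
out = layer(torch.randn(4, 8, 5, 5) * 10)
print(out.shape)  # torch.Size([4, 8, 5, 5])
```

The only learnable state is the `2 * nf` affine parameters; there are no buffers to track between batches.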
Since this looked encouraging, I wanted to try this modification on all the usual networks (ResNets, Inception, etc.) and see whether it generalizes to datasets that are more serious than MNIST.
But I didn’t want to write new model definitions for all these architectures, since I figured I’d end up writing a lot more of them this way, and that would slow me down.
So I went ahead and simply tried setting the `bn*` modules in the factory PyTorch `resnet` models to the `Tanh`-based replacement.
THIS DID NOT WORK.
It seemed to break some of the skip-connection logic, and I realized the same issue would probably crop up in other networks too.
After a bit of fiddling around, I found this bit of code to work:

    from typing import Callable

    import torch.nn as nn

    def recursive_getattr(obj: nn.Module, name: str):
        """ getattr for nested attributes with `.` in their names """
        sequence = name.split('.')
        if len(sequence) == 0 or name == '': return obj
        for attr in sequence: obj = getattr(obj, attr)
        return obj

    def recursive_setattr(obj: nn.Module, name: str, new_attr):
        """ setattr for nested attributes with `.` in their names """
        sequence = name.split('.')
        # Walk down to the parent module, then set the last attribute on it
        obj = recursive_getattr(obj, '.'.join(sequence[:-1]))
        setattr(obj, sequence[-1], new_attr)

    def modify_network(m: nn.Module, replace_func: Callable, condition: Callable = None) -> nn.Module:
        """ modifies `m` by replacing each module that satisfies `condition`
        with replace_func(module) """
        if condition is None:
            # If `replace_func` has a condition in-built, it should return
            # the passed module as it is. This is to prevent a recursion issue.
            condition = (lambda x: not x == replace_func(x))
        modules = list(m.named_modules())
        if len(modules) == 1:
            # `m` has no submodules, so replace (or keep) `m` itself
            return replace_func(m) if condition(m) else m
        # Skip the first entry: it is `m` itself, under the empty name ''
        for name, module in modules[1:]:
            if condition(module):
                recursive_setattr(m, name, replace_func(module))
        return m
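For comparison, newer PyTorch versions (1.9 onward, as far as I know) provide `nn.Module.get_submodule`, which resolves dotted names such as `layer1.0.bn1` directly, so the nested `getattr` helper isn’t strictly needed. Here is a sketch of the same replacement loop built on it; `replace_modules` is a hypothetical name, and it assumes the matched modules are not nested inside one another (true for BatchNorm layers).

```python
from typing import Callable

import torch
import torch.nn as nn


def replace_modules(model: nn.Module,
                    condition: Callable[[nn.Module], bool],
                    replace_func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """Replace every submodule for which `condition` holds with `replace_func(module)`."""
    # Collect dotted names first (skipping the root, whose name is ''),
    # so swapping modules does not disturb the traversal.
    targets = [name for name, mod in model.named_modules() if name and condition(mod)]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        # get_submodule("") returns the model itself, covering top-level children.
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, replace_func(getattr(parent, child_name)))
    return model


model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)),
)
# nn.Identity stands in for whatever replacement layer you want to test.
replace_modules(model, lambda m: isinstance(m, nn.BatchNorm2d), lambda m: nn.Identity())
print(model)
```

Passing the condition explicitly avoids calling `replace_func` just to decide whether a module should be replaced.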
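After this I can do what I originally intended with very little code:

    import torch.nn as nn
    from torchvision import models

    def bn_to_tanh(m: nn.Module) -> nn.Module:  # function name is illustrative
        if isinstance(m, nn.BatchNorm2d):
            nf = m.num_features
            m = TanHNorm(nf)  # A module I defined elsewhere
        return m

    m = models.resnet18(pretrained=False)
    m = modify_network(m, bn_to_tanh)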
This gives me back a working `resnet18` with all the BatchNorm layers replaced with a suitable `tanh` layer. I have yet to test it with other kinds of architectures.
In essence, this lets me take models created by anyone, with whatever logic they used in their script, and modify them as needed with very little code. This can be handy when trying out new activation functions, particular modifications to CNN layers, or, in general, adding or modifying arbitrary parts of existing networks with little effort. (Could this be added to fastai?)
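Trying a new activation function is the same pattern with a different condition. Another option is to recurse over `named_children()`, which sidesteps dotted names entirely. A minimal sketch, with `swap_children` as a hypothetical helper and `nn.ReLU` → `nn.GELU` chosen purely for illustration:

```python
import torch
import torch.nn as nn


def swap_children(module: nn.Module, old_type: type, make_new) -> nn.Module:
    """Recursively replace every submodule of type `old_type` with `make_new(old)`."""
    # list() snapshots the children before any of them are reassigned.
    for name, child in list(module.named_children()):
        if isinstance(child, old_type):
            setattr(module, name, make_new(child))
        else:
            # Only descend into children that were kept.
            swap_children(child, old_type, make_new)
    return module


model = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Linear(8, 2),
)
swap_children(model, nn.ReLU, lambda old: nn.GELU())
print(model)
```

Note that this only catches activations registered as modules; calls like `F.relu` inside a `forward` method are untouched by any attribute-swapping approach.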
It does work with FastAI’s (V1)
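Also, the graphs for the `tanh` experiment look like this:
[Figure: training graphs for the `tanh` experiment]
As opposed to the usual BatchNorm:
[Figure: training graphs for the usual BatchNorm]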
It’s not as smooth, but I think this is worth exploring further. The core advantage of this approach is that there is no batch mean or variance to compute and no running statistics to track (since those are generally what cause a lot of the problems with batch sizes and generalization). By getting rid of them, we may get rid of those problems too. (I’m still learning, though. Please, please correct me if this sounds naive or wrong.)
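One concrete way to see the difference: in training mode, `nn.BatchNorm2d` normalizes each sample with statistics of the whole batch (and keeps `running_mean`/`running_var` buffers), so a sample’s output changes when its batch-mates change. An element-wise `tanh` followed by a per-channel scale and shift does not depend on the rest of the batch. A small check, using a hypothetical stand-in `tanh_then_affine` with scale 1 and shift 0; it only demonstrates batch independence, not that training behaves as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(8, 3, 4, 4)
other = torch.randn(8, 3, 4, 4)
# Two batches that share the first sample but differ everywhere else.
batch_a = x
batch_b = torch.cat([x[:1], other[1:]])

bn = nn.BatchNorm2d(3)  # training mode by default: normalizes with per-batch statistics
bn_a, bn_b = bn(batch_a)[0], bn(batch_b)[0]


def tanh_then_affine(t: torch.Tensor, mults: float = 1.0, adds: float = 0.0) -> torch.Tensor:
    # Element-wise squash plus affine: no batch statistics involved.
    return torch.tanh(t) * mults + adds


th_a, th_b = tanh_then_affine(batch_a)[0], tanh_then_affine(batch_b)[0]

print("BatchNorm output for sample 0 unchanged:", torch.allclose(bn_a, bn_b))
print("tanh output for sample 0 unchanged:", torch.allclose(th_a, th_b))
print("BatchNorm buffers:", [name for name, _ in bn.named_buffers()])
```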
I want to start running experiments with ResNets on Imagenette using this approach as soon as possible, but right now I’m all caught up in preparing to go back to college. Hopefully, I’ll have something tangible to show soon enough.
As far as writing the blog is concerned, I was hoping to discover new things as I go along, ripping apart some of the core functions in `nn.Module` just like I did with `matplotlib`. A significant part of that was covered in Part 2 of the course itself, so that might help here too.
(I’m also thinking of exploring other libraries:
- Jax: NumPy with Autograd, XLA and GPUs
- streamlit: easily creating data science GUI tools/apps)