Implementing Binary Connect in Fastai

Hi! :smile:

I’m working on implementing quantized networks in Fastai. However, after a few experiments, I found that I can’t manage to replicate the results of a raw PyTorch implementation. For now I’m mainly interested in replicating Binary Connect (Courbariaux et al. 2015 https://arxiv.org/abs/1511.00363), and later in generalizing the callback to any quantization scheme.
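For reference, the paper binarizes each real-valued weight either deterministically (+1 if w ≥ 0, else -1) or stochastically (+1 with a probability given by a "hard sigmoid", else -1). A minimal NumPy sketch of both rules follows; the function names are my own and not taken from the baseline repo:

```python
import numpy as np

def hard_sigmoid(x):
    # sigma(x) = clip((x + 1) / 2, 0, 1), the hard sigmoid used in the paper
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)

def binarize_deterministic(w):
    # +1 where w >= 0, -1 otherwise (np.sign would map exact zeros to 0)
    return np.where(w >= 0, 1.0, -1.0)

def binarize_stochastic(w, rng):
    # +1 with probability hard_sigmoid(w), -1 otherwise
    p = hard_sigmoid(w)
    return np.where(rng.random(w.shape) < p, 1.0, -1.0)

w = np.array([-1.5, -0.2, 0.0, 0.3, 2.0])
print(binarize_deterministic(w))  # [-1. -1.  1.  1.  1.]
```

The expected value of the stochastic version is 2σ(w) - 1 = clip(w, -1, 1), so weights far outside [-1, 1] always binarize the same way.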

Following S. Gugger’s suggestion, I’m opening a thread to try to tackle this issue with the community! :slight_smile:

Here is the bug report I filed on the GitHub repo:
Describe the bug
I’m trying to write callbacks to run quantized networks using Fastai. As a baseline I use this PyTorch implementation, which works really well: https://github.com/eghouti/BinaryConnect
However, using the same parameters, optimizer, etc. in Fastai, my accuracy across a few runs struggles to reach 50% on CIFAR-10, while the baseline easily reaches 90% in roughly a dozen epochs. The loss saturates much more quickly in Fastai.
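One way such a callback could be organized (a sketch, not fastai code): keep the real-valued weights, binarize them in place before the forward pass, restore the real values after the backward pass so the optimizer updates them with the gradient computed through the binary weights, then clip them to [-1, 1] after the step. The class below is a hypothetical plain-NumPy mock whose method names mirror fastai v1’s on_batch_begin / on_backward_end / on_step_end events; in an actual callback the same copies would act on the `.data` of the model’s weights (which layers to binarize is an assumption left open here).

```python
import numpy as np

class BinaryConnectHooks:
    """Hypothetical mock of a BinaryConnect callback. Method names mirror
    fastai v1 callback events, but this is plain Python, not a fastai class."""

    def __init__(self, params):
        self.params = params  # dict name -> np.ndarray, modified in place
        self.saved = {}

    def on_batch_begin(self):
        # Save the real weights, then binarize in place for forward/backward.
        for name, p in self.params.items():
            self.saved[name] = p.copy()
            p[...] = np.where(p >= 0, 1.0, -1.0)

    def on_backward_end(self):
        # Gradients were computed with binary weights; restore the real
        # weights so the optimizer step updates those instead.
        for name, p in self.params.items():
            p[...] = self.saved[name]

    def on_step_end(self):
        # Clip the real weights to [-1, 1] after the optimizer step.
        for p in self.params.values():
            np.clip(p, -1.0, 1.0, out=p)

def train_step(params, hooks, x, y, lr):
    # Toy linear model with mean squared error, trained with plain SGD.
    hooks.on_batch_begin()
    w = params["w"]
    pred = x @ w                               # forward with binary weights
    grad = x.T @ (2 * (pred - y) / len(y))     # gradient w.r.t. binary weights
    hooks.on_backward_end()
    w -= lr * grad                             # SGD step on the real weights
    hooks.on_step_end()
```

The important ordering detail is that the optimizer must see the real weights, not the binarized ones; updating the binary copy would discard the small accumulated changes that BinaryConnect relies on.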

Provide your installation details

=== Software === 
python        : 3.7.3
fastai        : 1.0.52
fastprogress  : 0.1.21
torch         : 1.0.1
nvidia driver : 396.37
torch cuda    : 9.0.176 / is available
torch cudnn   : 7301 / is enabled

=== Hardware === 
nvidia gpus   : 4
torch devices : 4
  - gpu0      : 16280MB | Tesla P100-SXM2-16GB
  - gpu1      : 16280MB | Tesla P100-SXM2-16GB
  - gpu2      : 16280MB | Tesla P100-SXM2-16GB
  - gpu3      : 16280MB | Tesla P100-SXM2-16GB

=== Environment === 
platform      : Linux-3.10.0-957.1.3.el7.x86_64-x86_64-with-centos-7.6.1810-Core
distro        : #1 SMP Thu Nov 29 14:49:43 UTC 2018
conda env     : base
python        : /users/henwood/anaconda3/bin/python
sys.path      : /users/henwood
/users/henwood/anaconda3/lib/python37.zip
/users/henwood/anaconda3/lib/python3.7
/users/henwood/anaconda3/lib/python3.7/lib-dynload
/users/henwood/anaconda3/lib/python3.7/site-packages

To Reproduce
I created a gist you can copy and paste into a Jupyter Notebook: https://gist.github.com/sebastienwood/ac04d7296ecea9d3803fbd41038177f6

There is one part with the code from the baseline, adapted for a newer version of PyTorch, and one part with Fastai.

Expected behavior
Across a few tries, at least one Fastai run should converge to the same level of performance as the raw PyTorch one. However, it never reaches even 50% accuracy, and the loss saturates much more quickly.

Additional context
At the end of the 7th epoch, the raw PyTorch implementation usually has around 45% accuracy, while the Fastai implementation struggles to get past 25%.

The only difference between the two implementations should be the transforms: I used get_transforms() from Fastai, while the PyTorch implementation uses reflection padding + random crop + random horizontal flip.
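This difference may matter more than it looks: get_transforms() also adds rotation, zoom, lighting and warp augmentations, which can slow down convergence compared with the lighter CIFAR pipeline. The baseline’s pipeline is sketched below in NumPy, assuming the usual 4-pixel padding for 32x32 CIFAR images (the padding size is my assumption). If I remember the fastai v1 API correctly, something close to it would be `ds_tfms=([*rand_pad(4, 32), flip_lr(p=0.5)], [])` instead of get_transforms().

```python
import numpy as np

def pad_crop_flip(img, rng, pad=4):
    """img: H x W x C array. Reflect-pad by `pad` pixels, take a random
    H x W crop, then flip horizontally with probability 0.5."""
    h, w = img.shape[:2]
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)    # crop offset in [0, 2 * pad]
    left = rng.integers(0, 2 * pad + 1)
    out = padded[top:top + h, left:left + w]
    if rng.random() < 0.5:
        out = out[:, ::-1]                # horizontal flip
    return out

rng = np.random.default_rng(0)
img = np.arange(32 * 32 * 3, dtype=float).reshape(32, 32, 3)
print(pad_crop_flip(img, rng).shape)  # (32, 32, 3)
```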

Using the proposed WRN or Fastai’s WRN (same number of weights) doesn’t seem to make any difference.

Any suggestion is welcome! :slight_smile:

Update 1: the wd passed to the partial opt_func is overwritten by Learner. Correcting it yields better stats at first, but doesn’t fix the loss saturation issue.
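A related thing worth checking: if I remember fastai v1 correctly, Learner defaults to wd=1e-2 with true_wd=True and bn_wd=True, so weight decay is applied in a decoupled way (weights scaled by 1 - lr*wd before the optimizer step) and to BatchNorm parameters too, whereas the weight_decay argument of a torch.optim optimizer adds wd*p to the gradient. For momentum or adaptive optimizers these are not equivalent. The NumPy sketch below assumes an Adam-style optimizer (the baseline’s optimizer isn’t stated here) and shows both variants side by side:

```python
import numpy as np

def adam_step(p, grad, m, v, t, lr, wd, decoupled, b1=0.9, b2=0.99, eps=1e-8):
    """One Adam step on parameter array p; returns (p, m, v).
    decoupled=False: L2 penalty, wd * p added to the gradient
                     (like torch.optim.Adam's weight_decay argument).
    decoupled=True:  p scaled by (1 - lr * wd) outside the adaptive update."""
    if not decoupled:
        grad = grad + wd * p
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)             # bias-corrected first moment
    v_hat = v / (1 - b2 ** t)             # bias-corrected second moment
    if decoupled:
        p = p * (1 - lr * wd)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v

z = np.zeros(1)
p_dec, _, _ = adam_step(np.array([1.0]), z, z, z, t=1, lr=0.1, wd=0.1, decoupled=True)
p_l2, _, _ = adam_step(np.array([1.0]), z, z, z, t=1, lr=0.1, wd=0.1, decoupled=False)
```

With a zero data gradient, one step from p = 1 with lr = wd = 0.1 gives p ≈ 0.9 under the L2 form (Adam normalizes the penalty gradient) but p = 0.99 under the decoupled form. If this is the gap, passing true_wd=False (and possibly bn_wd=False) to Learner would be the thing to try.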

Hi again! :slight_smile:

After further debugging, the remaining problem is that convergence doesn’t match the baseline’s. Whereas one could hope for 90% accuracy in 30 epochs, my Fastai implementation plateaus at 60%. Here is the updated gist, which should fix most of the issues of the first one:

For a fair comparison, the baseline in raw PyTorch reaches 80% accuracy in at most 15 epochs. Any idea on how to fix this implementation is welcome! :slight_smile: