Reproducing "How to train your ResNet" using fastai

I chose to reproduce the process described in the article series "How to train your ResNet" (Part 1) as my toy project. It has turned out to be more difficult than I expected. If anyone wants to join me on this journey, you are more than welcome to!

The main problem I ran into and haven't solved yet when reproducing the baseline is that training is a lot slower than it should be. I am doing most of my testing in Google Colab (it is a toy project after all), but I also tested it on the recommended AWS p3.2xlarge instance and one epoch took about 28 seconds (instead of 7 seconds).

So my next step is to measure which part of the pipeline is so slow and why. Does anyone have experience with instrumenting the fastai v1 library for this purpose? Any help would be much appreciated.


What batch size were you using? And did you use mixed precision? 🙂 The latter can have an effect on training time (especially for language models).

Thanks for the fast reply. I am going to post a gist of the notebook (or repo if that is better?) shortly. I am using the same batch size they used for baseline, namely 128.


Here is a gist of the notebook: https://gist.github.com/davidpfahler/6f2be99f1c54aaa25aa4523ecd443676


I quickly profiled a one-epoch run like this:

# Profile one epoch; passing use_cuda=True to profile() would also record CUDA kernel times.
with torch.autograd.profiler.profile() as prof:
    learn.fit_one_cycle(1, max_lr=lr)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

which yielded the following table:

------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                                  Self CPU total %   Self CPU total      CPU total %        CPU total     CPU time avg     CUDA total %       CUDA total    CUDA time avg  Number of Calls
------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
_local_scalar_dense                            62.02%          79.461s           62.02%          79.461s          3.704ms              NaN          0.000us          0.000us            21454
to                                             14.31%          18.338s           14.42%          18.476s        703.222us              NaN          0.000us          0.000us            26274
cudnn_convolution_backward                      7.71%           9.876s            7.71%           9.876s          1.266ms              NaN          0.000us          0.000us             7800
pin_memory                                      7.66%           9.819s            7.68%           9.833s         10.483ms              NaN          0.000us          0.000us              938
cudnn_convolution                               2.50%           3.203s            2.50%           3.203s        341.474us              NaN          0.000us          0.000us             9380
add                                             0.60%        770.467ms            0.60%        770.467ms         26.833us              NaN          0.000us          0.000us            28713
add_                                            0.60%        762.582ms            0.60%        762.582ms         15.761us              NaN          0.000us          0.000us            48384
cudnn_batch_norm                                0.59%        756.170ms            0.59%        756.170ms        100.769us              NaN          0.000us          0.000us             7504
torch::autograd::AccumulateGrad                 0.54%        690.668ms            0.54%        690.668ms         32.795us              NaN          0.000us          0.000us            21060
mul_                                            0.54%        689.457ms            0.54%        689.457ms         16.390us              NaN          0.000us          0.000us            42066
sum                                             0.49%        625.302ms            0.49%        625.302ms         29.152us              NaN          0.000us          0.000us            21450
cudnn_batch_norm_backward                       0.34%        429.789ms            0.34%        429.789ms         68.876us              NaN          0.000us          0.000us             6240
relu_                                           0.25%        319.800ms            0.25%        319.800ms         42.617us              NaN          0.000us          0.000us             7504
div_                                            0.23%        296.166ms            0.23%        296.166ms         14.063us              NaN          0.000us          0.000us            21060
zero_                                           0.21%        272.455ms            0.21%        272.455ms          6.469us              NaN          0.000us          0.000us            42120
threshold_backward                              0.18%        231.264ms            0.18%        231.264ms         37.062us              NaN          0.000us          0.000us             6240
empty                                           0.11%        145.885ms            0.11%        145.885ms         10.601us              NaN          0.000us          0.000us            13761
_convolution                                    0.10%        127.225ms            2.61%           3.348s        356.903us              NaN          0.000us          0.000us             9380
_batch_norm_impl_index                          0.09%        117.206ms            0.71%        915.070ms        121.944us              NaN          0.000us          0.000us             7504
CudnnBatchNormBackward                          0.09%        111.281ms            0.44%        557.354ms         89.320us              NaN          0.000us          0.000us             6240
mul                                             0.07%         95.310ms            0.07%         95.310ms         58.187us              NaN          0.000us          0.000us             1638
CudnnConvolutionBackward                        0.06%         83.172ms            7.77%           9.959s          1.277ms              NaN          0.000us          0.000us             7800
contiguous                                      0.06%         75.477ms            0.06%         75.477ms          1.420us              NaN          0.000us          0.000us            53140
cat                                             0.05%         59.615ms            0.05%         59.615ms        127.111us              NaN          0.000us          0.000us              469
item                                            0.04%         55.755ms           62.06%          79.517s          3.706ms              NaN          0.000us          0.000us            21454
ReluBackward1                                   0.04%         53.854ms            0.22%        285.119ms         45.692us              NaN          0.000us          0.000us             6240
batch_norm                                      0.04%         51.705ms            0.75%        966.775ms        128.835us              NaN          0.000us          0.000us             7504
addmm                                           0.04%         49.332ms            0.04%         49.332ms        105.185us              NaN          0.000us          0.000us              469
max_pool2d_with_indices                         0.03%         43.607ms            0.03%         43.607ms         92.978us              NaN          0.000us          0.000us              469
conv2d                                          0.03%         38.564ms            2.66%           3.408s        363.334us              NaN          0.000us          0.000us             9380
mm                                              0.03%         33.100ms            0.03%         33.100ms         42.436us              NaN          0.000us          0.000us              780
div                                             0.03%         32.895ms            0.03%         32.895ms         38.250us              NaN          0.000us          0.000us              860
nll_loss_forward                                0.02%         30.601ms            0.02%         30.601ms         65.247us              NaN          0.000us          0.000us              469
avg_pool2d_backward                             0.02%         29.856ms            0.02%         29.856ms         76.553us              NaN          0.000us          0.000us              390
nll_loss_backward                               0.02%         27.793ms            0.02%         27.793ms         71.264us              NaN          0.000us          0.000us              390
view                                            0.02%         26.096ms            0.02%         26.096ms         13.348us              NaN          0.000us          0.000us             1955
sub                                             0.02%         25.801ms            0.02%         25.801ms         55.013us              NaN          0.000us          0.000us              469
max_pool2d_with_indices_backward                0.02%         25.034ms            0.02%         25.034ms         64.189us              NaN          0.000us          0.000us              390
_log_softmax                                    0.02%         24.982ms            0.02%         24.982ms         53.266us              NaN          0.000us          0.000us              469
avg_pool2d                                      0.02%         22.046ms            0.02%         22.046ms         47.007us              NaN          0.000us          0.000us              469
convolution                                     0.02%         21.756ms            2.63%           3.370s        359.222us              NaN          0.000us          0.000us             9380
_log_softmax_backward_data                      0.01%         15.687ms            0.01%         15.687ms         40.223us              NaN          0.000us          0.000us              390
detach_                                         0.01%         15.225ms            0.01%         15.225ms          0.354us              NaN          0.000us          0.000us            43058
unsqueeze                                       0.01%         13.424ms            0.01%         13.424ms          7.156us              NaN          0.000us          0.000us             1876
unsigned short                                  0.01%         13.200ms            0.01%         13.200ms          6.505us              NaN          0.000us          0.000us             2029
mean                                            0.01%         11.071ms            0.01%         11.071ms        140.135us              NaN          0.000us          0.000us               79
_th_set_                                        0.01%         10.898ms            0.01%         10.898ms          5.809us              NaN          0.000us          0.000us             1876
max                                             0.01%          7.375ms            0.01%          7.375ms         93.349us              NaN          0.000us          0.000us               79
AddmmBackward                                   0.00%          6.033ms            0.04%         44.866ms        115.042us              NaN          0.000us          0.000us              390
transpose                                       0.00%          5.780ms            0.00%          5.780ms          4.352us              NaN          0.000us          0.000us             1328
NllLossBackward                                 0.00%          5.472ms            0.03%         33.265ms         85.296us              NaN          0.000us          0.000us              390
MulBackward0                                    0.00%          5.245ms            0.02%         30.771ms         78.900us              NaN          0.000us          0.000us              390
set_                                            0.00%          4.439ms            0.01%         15.338ms          8.176us              NaN          0.000us          0.000us             1876
_th_eq                                          0.00%          4.265ms            0.00%          4.265ms         53.991us              NaN          0.000us          0.000us               79
reshape                                         0.00%          4.146ms            0.01%          8.226ms         10.546us              NaN          0.000us          0.000us              780
as_strided                                      0.00%          4.079ms            0.00%          4.079ms          5.230us              NaN          0.000us          0.000us              780
nll_loss                                        0.00%          4.011ms            0.03%         34.612ms         73.799us              NaN          0.000us          0.000us              469
max_pool2d                                      0.00%          3.553ms            0.04%         47.159ms        100.553us              NaN          0.000us          0.000us              469
log_softmax                                     0.00%          3.450ms            0.02%         28.432ms         60.623us              NaN          0.000us          0.000us              469
AddBackward0                                    0.00%          3.408ms            0.00%          3.408ms          1.092us              NaN          0.000us          0.000us             3120
clone                                           0.00%          3.165ms            0.00%          3.165ms         29.303us              NaN          0.000us          0.000us              108
detach                                          0.00%          3.038ms            0.00%          3.038ms          2.905us              NaN          0.000us          0.000us             1046
MaxPool2DWithIndicesBackward                    0.00%          3.001ms            0.02%         28.035ms         71.884us              NaN          0.000us          0.000us              390
LogSoftmaxBackward                              0.00%          2.943ms            0.01%         18.630ms         47.768us              NaN          0.000us          0.000us              390
torch::autograd::CopyBackwards                  0.00%          2.793ms            0.02%         22.582ms         57.904us              NaN          0.000us          0.000us              390
AvgPool2DBackward                               0.00%          2.288ms            0.03%         32.144ms         82.421us              NaN          0.000us          0.000us              390
slice                                           0.00%          2.214ms            0.00%          2.214ms          2.838us              NaN          0.000us          0.000us              780
narrow                                          0.00%          2.021ms            0.00%          4.235ms          5.429us              NaN          0.000us          0.000us              780
ViewBackward                                    0.00%          1.879ms            0.01%         10.105ms         12.956us              NaN          0.000us          0.000us              780
CatBackward                                     0.00%          1.847ms            0.00%          6.082ms         15.594us              NaN          0.000us          0.000us              390
argmax                                          0.00%        885.046us            0.01%          8.260ms        104.552us              NaN          0.000us          0.000us               79
TBackward                                       0.00%        832.800us            0.00%          3.042ms          7.799us              NaN          0.000us          0.000us              390
TransposeBackward0                              0.00%        764.504us            0.00%          1.997ms          5.121us              NaN          0.000us          0.000us              390
eq                                              0.00%        585.620us            0.00%          4.851ms         61.404us              NaN          0.000us          0.000us               79
torch::autograd::GraphRoot                      0.00%        408.692us            0.00%        408.692us          1.048us              NaN          0.000us          0.000us              390
is_floating_point                               0.00%        383.389us            0.00%        383.389us          0.856us              NaN          0.000us          0.000us              448
stack                                           0.00%        225.215us            0.00%        225.215us        225.215us              NaN          0.000us          0.000us                1
random_                                         0.00%         59.977us            0.00%         59.977us         29.989us              NaN          0.000us          0.000us                2
is_complex                                      0.00%          1.687us            0.00%          1.687us          0.844us              NaN          0.000us          0.000us                2
------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Self CPU time total: 128.119s
CUDA time total: 0.000us

Naturally, I'm now looking into what could cause the long time spent in _local_scalar_dense. This issue comment suggests that it might be caused by .item() calls on tensors.
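
To see why .item() would show up like this, here is a minimal sketch (plain PyTorch, not from my notebook): .item() copies a single scalar from the GPU to the CPU, which forces the CPU to wait for all GPU work queued up to that point.

import time
import torch

x = torch.randn(4096, 4096, device="cuda")
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(100):
    y = x @ x                                  # matmuls are queued asynchronously; Python moves on immediately
elapsed_async = time.perf_counter() - start    # tiny, the GPU hasn't finished yet

value = y.sum().item()                         # copies one scalar to the CPU -> blocks until all queued work is done
elapsed_sync = time.perf_counter() - start

print(f"after queuing: {elapsed_async:.3f}s, after .item(): {elapsed_sync:.3f}s")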


This is an important project. This model is SoTA in terms of speed for CIFAR10 (94% accuracy).

I also had trouble reproducing the results in fastai but didn’t try much.

If you succeed, you might find ways to make fastai faster, and it will also give us more familiar code that we can all use easily to try to beat the DAWNBench result.

Hello all!

I have been doing a pretty deep dive into this work as I am interested in decreasing the time of training for another project of mine. So let me know if you have any questions.

Since it wasn't linked, here is the repository. Here is the final post, where they got training on CIFAR10 down to 34 seconds with accuracy over 94% (for some reason the post was taken down from myrtle.ai).

Several things to note. Of course it is going to be slower on Google Colab, as the experiments were run on a V100 GPU, not the T4 GPU offered by Colab. Second, myrtle.ai used their own source code and did not use a lot of the defaults offered in PyTorch (for example, instead of using DataLoaders directly, they replaced them with their own class called Batches, which builds on top of DataLoader). This may be one reason why the fastai version is so slow in comparison.

Fastai also has its own defaults. For example, fastai does not put the batchnorm layers in fp16, which, as you will see in the myrtle.ai series, actually does improve speed. There may be other fastai defaults that decrease speed as well.
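
If you want to experiment with that particular default, a rough sketch (plain PyTorch, not a fastai API; assumes your Learner is called learn) of forcing the BatchNorm layers to fp16 after to_fp16 could look like this; whether that is numerically stable for your model is something you would have to check:

import torch.nn as nn

def batchnorm_to_half(model: nn.Module) -> nn.Module:
    # fastai's to_fp16 keeps BatchNorm in fp32; this forces it to fp16 for comparison.
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.half()
    return model

learn.model = batchnorm_to_half(learn.model)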

So it is important to go through the source code of both packages and make sure they are working identically.

Looking at your profile, it seems that a lot of the time-consuming functions have to do with transferring data between CPU and GPU. Make sure you are not needlessly transferring back and forth between CPU and GPU.


Thank you for the encouragement! Just so that others reading this don’t get confused: I was referring to Part 1 of the article series, where they reimplemented the then SoTA of 341 seconds. My goal is to eventually get to the SoTA by going through the article series one-by-one.


Absolutely. That's why I tested it on the recommended AWS p3.2xlarge instance, and it was still 4x slower.

Totally, that's my goal. I thought the best way to get there was to start at Part 1: Baseline in their article series and reproduce that first. But that's where I am already seeing a 4x slowdown. Also, I want to write this in fastai as much as possible, so if something in fastai is slowing this down, I want to find out what it is and then replace as little of it as possible with my own code.

Do you have any experience in profiling? I don't, so it's pretty hard for me to find out what is causing such a slowdown. It is of course possible that there is simply a bug in my code and I didn't correctly reproduce their setup, but if so, I can't find it.

I do have some experience with profiling, but not with profiling code that uses CUDA. I have used the Spyder IDE profiler in the past, and it's pretty nice, but that wouldn't work for you if you are working in a Jupyter notebook.

But based on the results you shared, I noticed two things. I am unsure about _local_scalar_dense, but I know that to and pin_memory are functions involved in transferring between CPU and GPU. fastai sets pin_memory=True as a default; setting it to True is technically supposed to speed things up, but try without it and see if that's the problem.

EDIT: Never mind, the myrtle.ai code also uses pin_memory=True. But the transferring between CPU and GPU still seems to be the major problem.
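
For what it's worth, one rough way to isolate the pin_memory effect is to time a plain PyTorch DataLoader with it on and off (a sketch using a synthetic CIFAR-sized dataset, not your actual fastai pipeline):

import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for CIFAR10-sized data, just to time the loader plus the host-to-device copy.
ds = TensorDataset(torch.randn(50000, 3, 32, 32), torch.randint(0, 10, (50000,)))

for pin in (True, False):
    dl = DataLoader(ds, batch_size=128, num_workers=4, pin_memory=pin)
    start = time.perf_counter()
    for xb, yb in dl:
        xb = xb.cuda()
    torch.cuda.synchronize()
    print(f"pin_memory={pin}: {time.perf_counter() - start:.1f}s")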

Thanks again. I will keep looking into this and report back with results.


I found this:

Since _local_scalar_dense is mainly used in the .item() function, it may be down to CUDA synchronization rather than an actual bottleneck in the code.

Yes, _local_scalar_dense looks related to synchronisation when returning values from the GPU to the CPU, as in .item() and, I think, some other functions. It can be an actual bottleneck, though. While in synthetic results it is not indicative of what is actually taking time, it can slow down training: I gather it effectively drains the queue of GPU work by making the CPU wait for outstanding GPU work to complete, thus slowing down overall training.
For instance, I'm working on some code running in a callback hook, and using .item() (or similar) results in a significant slowdown of overall training. I saw times go from 8 seconds an epoch without the code to 20 seconds best case (minimising all such calls) and to over a minute in some of my attempts.
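
To illustrate the kind of change that helped in my case, here is a simplified sketch (a toy model, not my actual callback) of accumulating on the GPU and calling .item() once per epoch instead of once per batch:

import torch
import torch.nn as nn

model = nn.Linear(10, 1).cuda()                       # toy model, just to show the pattern
loss_func = nn.MSELoss()
batches = [(torch.randn(64, 10, device="cuda"),
            torch.randn(64, 1, device="cuda")) for _ in range(100)]

# Slow pattern: one GPU->CPU synchronisation per batch.
running_loss = 0.0
for xb, yb in batches:
    running_loss += loss_func(model(xb), yb).item()   # .item() blocks until the GPU catches up

# Faster pattern: keep the accumulator on the GPU and synchronise once at the end.
running_loss = torch.zeros(1, device="cuda")
for xb, yb in batches:
    running_loss += loss_func(model(xb), yb).detach() # stays on the GPU, no sync
epoch_loss = running_loss.item()                      # a single sync per epoch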

Also, pin_memory moves tensors from normal pageable memory to page-locked (pinned) memory, so the physical memory address won't be changed by the OS before it is used. One use is to allow non-blocking transfer of data between CPU and GPU, though I'm not sure PyTorch uses this by default (you have to use multiple streams and explicitly use particular methods such as Tensor.copy_(non_blocking=True), and you are then responsible for synchronisation). I think it is also used in PyTorch for copying between the main process and worker processes, which may be where it is being used.

You may also want to look at the output of the Python profiler, as it gives a little more context on where things are happening: it includes time spent in child calls, which PyTorch's profiler doesn't seem to, and it shows all functions rather than just the NVTX ranges (custom names you can assign for the NVIDIA profiler tools to use).
You may also find it useful to add your own NVTX ranges, which let you add custom groupings to the profiler result to see where stuff is happening. You can use torch.autograd._push_range and torch.autograd._pop_range as in:

from torch.autograd import _push_range, _pop_range

_push_range('My Range')   # everything until the matching pop gets grouped under this name
something()               # the code you want to measure
_pop_range()

I haven't tried this when profiling torch modules, which create their own ranges, but I have used it with basic torch code, and the ranges display in the torch.autograd.profiler output.
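
If those underscore-prefixed helpers aren't available in your PyTorch version, torch.cuda.nvtx provides equivalent range markers; a sketch (these ranges show up in the NVIDIA tools such as nvprof or Nsight Systems rather than in the torch.autograd.profiler table):

import torch

with torch.autograd.profiler.emit_nvtx():      # makes autograd ops emit NVTX markers too
    torch.cuda.nvtx.range_push("my_range")     # custom grouping visible in the NVIDIA timeline
    something()                                # the code you want to measure (placeholder, as above)
    torch.cuda.nvtx.range_pop()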


PyTorch does not use pin_memory=True by default. However, fastai and myrtle.ai do.

I also agree that a traditional Python profiler might be better for this case. If the .item() calls truly are the problem, a traditional profiler might help figure out which fastai/PyTorch functions are calling it.

Yeah, I was saying I don't think asynchronous copies are the default. You are quite right that the default for pin_memory on a data loader is False, meaning it does not return pinned tensors by default. From a quick search this does seem to be the only place pinning is used; workers use a different shared-memory mechanism.
The docs also suggest that pin_memory allows faster synchronous copies as well as being a prerequisite for asynchronous copies, which I don't think happen by default in PyTorch given the need for multiple streams and explicit synchronisation.
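
For reference, a minimal sketch of what an explicit asynchronous host-to-device copy looks like; as noted above, it only actually overlaps with other work when the source tensor is pinned:

import torch

cpu_batch = torch.randn(128, 3, 32, 32).pin_memory()   # page-locked host memory
gpu_batch = cpu_batch.to("cuda", non_blocking=True)    # returns immediately; the copy runs asynchronously
# ... other CPU work or GPU kernel launches can overlap with the copy here ...
torch.cuda.synchronize()                                # wait for the copy (and queued kernels) to finish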

@ilovescience @TomB Thanks, both of you! This is all a little hard for me to understand, as I have never done Python profiling before, but I think I get the gist (please correct me where I am wrong): basically, .item() may or may not itself cause the slowdown (the timings could also just be attributed to .item() because it is a synchronization point, as described in this post). To find out which fastai/PyTorch functions are causing this slowdown (e.g. by calling .item()), I should look at the output of the Python profiler.

Which profiler do you recommend I should use: cProfile, profile or hotshot?

Yeah, pretty much. .item() seeming to eat up lots of time isn't necessarily bad (it certainly doesn't mean you should try to optimise .item() itself). But having a lot of .item() calls in the training loop can significantly slow down training, as it means the GPU is sitting idle until you give it more work. Each call should be fairly fast, a few ms, but if it happens a lot it will slow everything down, so these are good things to look at.

The Python profiler might help here, though it isn't great. You can profile in Jupyter with the %prun magic (or add %%prun to the top of a cell to profile all the code in the cell). For instance, profiling some code I'm working on, I have a cell with:

%%prun -s cumulative -l 40 
code to profile

This sorts by cumulative time and shows the top 40 results, which gives:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      3/1    0.000    0.000   86.046   86.046 {built-in method builtins.exec}
        1    0.000    0.000   86.046   86.046 <string>:1(<module>)
        1    0.000    0.000   86.042   86.042 train.py:14(fit_one_cycle)
        1    0.000    0.000   86.042   86.042 basic_train.py:193(fit)
        1    0.000    0.000   86.041   86.041 basic_train.py:85(fit)
       15    0.000    0.000   82.222    5.481 basic_train.py:20(loss_batch)
  1710/30    0.008    0.000   76.246    2.542 module.py:537(__call__)
   525/15    0.002    0.000   76.231    5.082 container.py:90(forward)
      690    0.003    0.000   58.933    0.085 hooks.py:18(hook_fn)
      690    0.012    0.000   58.927    0.085 tensorboard.py:156(hook)
       92    0.004    0.000   58.913    0.640 tensorboard.py:90(make_histogram)
       92    0.032    0.000   58.850    0.640 exp_histogram.py:25(exp_histogram)
       92   58.432    0.635   58.432    0.635 {method 'min' of 'torch._C._TensorBase' objects}
      120    0.003    0.000   47.644    0.397 xresnet.py:47(forward)

From this I can figure out that the time is coming from calling torch.min, and I can see the path of calls that led to it (all the calls with basically the same cumtime). Of course, it may be harder if the slowdown is not concentrated at one point as it is here.

Oh, and be aware that running with %prun is significantly slower: it took 80s to profile what would normally be an 8-second epoch, so you want to minimise the dataset size if possible.
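
As for which profiler to use: as far as I know %prun is built on cProfile, so outside a notebook you can get the same report directly (a sketch, assuming the learn and lr from your gist):

import cProfile
import pstats

# Profile one epoch and dump the raw stats to a file.
cProfile.run("learn.fit_one_cycle(1, max_lr=lr)", "fit.prof")

# Load, sort by cumulative time and print the top 40 entries, like %prun -s cumulative -l 40.
stats = pstats.Stats("fit.prof")
stats.sort_stats("cumulative").print_stats(40)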

@davidpfahler Have you been able to obtain a profile of your code? Please update this thread with your results. I would love to follow your progress, as I am also quite interested in optimizing training code in fastai.

Thanks for your interest. I am doing what I can and will definitely keep you updated here. I have since tried to reproduce the baseline with fastai v2, which is even slower, but apparently for different reasons. Before I continue with my inquiry, I am debating whether I should focus completely on fastai v2, as it is the future of fastai and still very malleable. If you have an opinion on that question either way, I would like to hear it. It just feels like spending time on optimizing v1 will be wasted in a few months' time.

I am not sure either, but I can provide reasons for both cases. The reason for using fastai v1 is that its API will not change at all, whereas the fastai v2 API is changing very often, so code written now might be slightly out-of-date by the time fastai v2 is released in a few months. However, I am intrigued that you say fastai v2 is slower for different reasons. If you are able to figure that out, it would be a good contribution to the fastai v2 library.