Learner for Tensorflow


(Bryan Heffernan) #1

I would like to share a Learner class that has most of the functionality of a Fastai Pytorch Learner, except that it is ported to use Tensorflow/Keras models, optimizers, and loss functions. It uses Fastai’s DataBunch, which means all the awesome data augmentation and dataloader features work. It also uses the callback system, which means Fastai’s learning rate / momentum schedulers, LR finder, etc. work seamlessly. A link to the dev notebook can be found here:

This is meant to be a drop-in addition to the current Fastai library and requires no changes to preexisting code. The classes are also structured to be as close to Fastai’s basic_train.py classes as possible. I will explain how some of this is implemented.

First up is the TFLearner class. The constructor requires a tf.keras.Model or a Keras functional API model (docs: https://www.tensorflow.org/api_docs/python/tf/keras/models/Model), a DataBunch object, a tf.train optimizer such as tf.train.AdamOptimizer, a tf.losses function such as tf.losses.sparse_softmax_cross_entropy, and Tensorflow/Keras-compatible metrics. All other parameters are the same as for a regular Learner. Pay attention to the parameter order of custom loss functions and metrics: tf.losses functions take the labels first (e.g. tf.losses.sparse_softmax_cross_entropy(labels, logits)), whereas Fastai’s Pytorch loss functions take the model output first. Below is a list of the currently working features and of features that could be added in the future.
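To make the argument-order pitfall concrete, here is a small framework-free sketch (plain Python, no Tensorflow needed). `sparse_ce` is a hypothetical custom loss written in the Pytorch/Fastai (output, target) order, and `swap_args` is a hypothetical helper that re-exposes it in the Tensorflow (labels, logits) order. Which order TFLearner itself calls custom functions with is an assumption to check against the notebook.

```python
import math
from typing import Callable, Sequence


def sparse_ce(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical custom loss in (output, target) order: mean sparse
    softmax cross-entropy over a batch of logit rows and integer labels."""
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


def swap_args(fn: Callable) -> Callable:
    """Hypothetical adapter: turn fn(a, b) into g(b, a), e.g. an
    (output, target) loss into a (labels, logits) loss."""
    def swapped(first, second):
        return fn(second, first)
    # Copy only the name; functools.wraps would also copy the signature,
    # which would then advertise the wrong argument order.
    swapped.__name__ = getattr(fn, "__name__", "swapped")
    return swapped


tf_order_ce = swap_args(sparse_ce)
logits = [[2.0, 0.0], [0.0, 0.0]]
labels = [0, 1]
# Both calls compute the same value; only the argument order differs.
print(sparse_ce(logits, labels), tf_order_ce(labels, logits))
```

In this sketch, calling `sparse_ce(labels, logits)` by mistake fails immediately (`max()` is applied to an int), which is the kind of error the parameter-order note warns about.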

The fit method calls an external fit function, which iterates through the dataloaders and calls the loss_batch function on each batch. These functions are laid out almost identically to those of the Fastai Learner. The only hackery is in loss_batch, where the Pytorch batch tensor is moved to the CPU, converted to numpy, and then put back on the GPU as a tf.Tensor. This round trip is inefficient, but it allows the use of Fastai’s transforms. I have not measured the exact performance impact, but models still train in reasonable time. Next, the tf.Tensor is run through the model, the loss is calculated, gradients are computed with tf.GradientTape, and the optimizer step is run.
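The control flow described above can be sketched without either framework. This is a hypothetical, simplified version: plain Python lists stand in for tensors, a one-feature linear model stands in for the tf.keras.Model, a hand-derived mean-squared-error gradient stands in for tf.GradientTape, and a tiny SGD class stands in for the tf.train optimizer. The comment in loss_batch marks where TFLearner does its torch-to-numpy-to-tf.Tensor conversion.

```python
class LinearModel:
    """Toy stand-in for a tf.keras.Model: y = w * x + b."""
    def __init__(self, w: float = 0.0, b: float = 0.0):
        self.w, self.b = w, b

    def __call__(self, xs):
        return [self.w * x + self.b for x in xs]


class SGD:
    """Toy stand-in for a tf.train optimizer: plain gradient descent."""
    def __init__(self, model: LinearModel, lr: float):
        self.model, self.lr = model, lr

    def step(self, grads: dict) -> None:
        for name, g in grads.items():
            setattr(self.model, name, getattr(self.model, name) - self.lr * g)


def mse_with_grads(model, xs, ys):
    """Forward pass, MSE loss and its gradients w.r.t. w and b.
    In the real learner, tf.GradientTape would compute these gradients."""
    errs = [p - y for p, y in zip(model(xs), ys)]
    n = len(errs)
    loss = sum(e * e for e in errs) / n
    grads = {"w": 2 * sum(e * x for e, x in zip(errs, xs)) / n,
             "b": 2 * sum(errs) / n}
    return loss, grads


def loss_batch(model, xb, yb, opt=None) -> float:
    # This is where TFLearner moves the Pytorch batch to the CPU, converts it to
    # numpy and rebuilds it as a tf.Tensor; here we just copy to float lists.
    xs, ys = [float(x) for x in xb], [float(y) for y in yb]
    loss, grads = mse_with_grads(model, xs, ys)
    if opt is not None:  # training batch: take an optimizer step
        opt.step(grads)
    return loss


def fit(epochs, model, opt, train_dl, valid_dl):
    """Iterate over the dataloaders, calling loss_batch on every batch;
    returns the mean validation loss after each epoch."""
    history = []
    for _ in range(epochs):
        for xb, yb in train_dl:
            loss_batch(model, xb, yb, opt)
        val_losses = [loss_batch(model, xb, yb) for xb, yb in valid_dl]
        history.append(sum(val_losses) / len(val_losses))
    return history


# Data from y = 3x + 1, split into "dataloaders" of small batches.
train_dl = [([0, 1, 2], [1, 4, 7]), ([3, 4], [10, 13])]
valid_dl = [([5], [16])]
model = LinearModel()
history = fit(500, model, SGD(model, lr=0.02), train_dl, valid_dl)
print(round(model.w, 2), round(model.b, 2))
```

Fastai’s own loss_batch additionally passes the loss through the callback handler before the backward pass; that part is left out here to keep the loop shape visible.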

Next up is the TFOptimWrapper class, the Tensorflow equivalent of Fastai’s OptimWrapper. Learning rate schedules in Tensorflow are annoying, so this class wraps the optimizer the same way OptimWrapper does, which lets Fastai callbacks change the learning rate during training. Currently it only handles learning rate and momentum; other hyperparameters can be added in the future.
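A minimal sketch of the wrapper idea, in plain Python so it runs anywhere. It assumes the optimizer re-reads its hyperparameters on every step when they are given as zero-argument callables (the TF 1.x tf.train optimizers accept callables for arguments such as learning_rate under eager execution; worth double-checking for your TF version). Whether the notebook uses this mechanism is an assumption, and `ToyMomentumSGD` and `OptimWrapperSketch` are hypothetical names. Like Fastai’s OptimWrapper, the wrapper exposes `lr` and `mom` properties, which are what the scheduling callbacks set.

```python
class ToyMomentumSGD:
    """Stand-in for tf.train.MomentumOptimizer. Hyperparameters may be plain
    numbers or zero-argument callables that are re-evaluated on every step."""
    def __init__(self, learning_rate, momentum):
        self._lr, self._mom = learning_rate, momentum
        self._accum = {}  # one momentum accumulator per parameter name

    @staticmethod
    def _resolve(h):
        return h() if callable(h) else h

    def step(self, params: dict, grads: dict) -> None:
        lr, mom = self._resolve(self._lr), self._resolve(self._mom)
        for name, g in grads.items():
            acc = mom * self._accum.get(name, 0.0) + g  # accum = mom * accum + grad
            self._accum[name] = acc
            params[name] -= lr * acc                     # var -= lr * accum


class OptimWrapperSketch:
    """Hypothetical TFOptimWrapper-like class: callbacks set .lr and .mom,
    and the wrapped optimizer picks up the new values on its next step."""
    def __init__(self, opt_cls, lr: float = 1e-3, mom: float = 0.9):
        self._lr, self._mom = lr, mom
        # The lambdas close over self, so later assignments reach the optimizer.
        self.opt = opt_cls(learning_rate=lambda: self._lr,
                           momentum=lambda: self._mom)

    @property
    def lr(self) -> float:
        return self._lr

    @lr.setter
    def lr(self, val: float) -> None:
        self._lr = float(val)

    @property
    def mom(self) -> float:
        return self._mom

    @mom.setter
    def mom(self, val: float) -> None:
        self._mom = float(val)

    def step(self, params: dict, grads: dict) -> None:
        self.opt.step(params, grads)


params = {"w": 1.0}
wrapper = OptimWrapperSketch(ToyMomentumSGD, lr=0.1, mom=0.0)
wrapper.step(params, {"w": 1.0})  # uses lr=0.1
wrapper.lr = 0.5                  # what a scheduler callback would do
wrapper.step(params, {"w": 1.0})  # uses lr=0.5 without rebuilding the optimizer
print(params["w"])  # 0.4
```

The design choice is that the optimizer is built once; only the values behind the callables change, so a scheduler can update them every batch at no cost.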

List of currently supported features:

  • tf.keras.Model models
  • Keras functional API models
  • layer freezing
  • model saving/loading
  • DataBunch data
  • fit
  • lr_find
  • fit_one_cycle
  • learning rate schedules
  • momentum schedules
  • recorder callback - graphing
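For the scheduling features listed above, the callbacks push a new (learning rate, momentum) pair into the optimizer wrapper before each batch. As a rough, framework-free illustration, here is a sketch of the one-cycle values, assuming the Fastai v1 fit_one_cycle defaults (moms=(0.95, 0.85), div_factor=25, pct_start=0.3, final learning rate lr_max / (div_factor * 1e4)) with cosine annealing in both phases. `annealing_cos` mirrors the Fastai function of the same name; `one_cycle` is a hypothetical helper, not part of Fastai.

```python
import math


def annealing_cos(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle(n_iter, lr_max, moms=(0.95, 0.85), div_factor=25.0,
              pct_start=0.3, final_div=None):
    """Return one (lr, mom) pair per iteration for a one-cycle schedule."""
    final_div = final_div if final_div is not None else div_factor * 1e4
    n1 = int(n_iter * pct_start)  # warm-up iterations
    n2 = n_iter - n1              # annealing iterations
    sched = []
    for i in range(n_iter):
        if i < n1:  # phase 1: LR rises to lr_max while momentum falls
            pct = i / n1
            lr = annealing_cos(lr_max / div_factor, lr_max, pct)
            mom = annealing_cos(moms[0], moms[1], pct)
        else:       # phase 2: LR decays towards lr_max / final_div, momentum rises
            pct = (i - n1) / n2
            lr = annealing_cos(lr_max, lr_max / final_div, pct)
            mom = annealing_cos(moms[1], moms[0], pct)
        sched.append((lr, mom))
    return sched


for lr, mom in one_cycle(10, 1e-2)[:3]:
    print(f"lr={lr:.5f} mom={mom:.3f}")
```

A scheduler callback would assign each pair to the wrapper’s lr and mom before the corresponding batch; the learner and optimizer code never need to know a schedule is running.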

To be added in the future:

  • discriminative learning rates
  • gradient clipping
  • more hyperparameters (alpha/beta)
  • weight decay
  • true weight decay
  • batchnorm weight decay option
  • freeze batchnorm layers

You may be wondering what the purpose of this TFLearner is. Why not just use the Pytorch models that Fastai already supports? Don’t get me wrong, I love Pytorch, but Tensorflow/Keras is still the most popular framework in the community. Sometimes I want to take someone else’s Keras model and test it with hyperparameter schedules and data augmentation, or try out a paper that has already been implemented in Keras.

If the Fastai lead developers find it acceptable, I would like to add this Learner to Fastai. I think it would be a good addition to the framework because it’s a drop-in solution that doesn’t require modifying any existing classes. It would also let people who don’t know Pytorch use Fastai’s core features such as hyperparameter schedules, dataloader transformations, the learning rate finder, etc.

I spent a lot of time testing this and making sure it doesn’t break anything in the Fastai library, but please let me know if you find any bugs. I also learned a lot about Tensorflow eager execution and the Fastai code base from this. So thanks to all the Fastai developers who made fastai_v1 so great!


#2

Isn’t that possible with ONNX export/import? https://github.com/onnx/tensorflow-onnx