I started writing this post while my very first fastai-based pipeline, with a custom Dataset, DataLoader and Learner, was training in the background. Getting there took a lot of pain, and I want to help those of you who are going down the same path for the first time.
If you want to build something with fastai that wasn’t covered in Jeremy’s awesome DL course, you’ll need to understand fastai’s structure first, so my first suggestion is to look at this mindmap created by @shaun1 (link to his post).
It gives you the big picture, but it isn’t everything.
For a pipeline you need:
Dataset
- Base fastai class: `BaseDataset`
- Knows how to open your data by index
- Example: feed it a list of image filenames and a list of labels, then create a `get_x` function that returns the image and a `get_y` function that returns the label in `np.array` format
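To make the "open your data by index" idea concrete, here is a minimal plain-Python sketch. It is not fastai’s `BaseDataset` (check the fastai source for the exact methods that class expects); the name `ToyImageDataset` and the `open_fn` callable, which stands in for real image loading so the sketch runs without files, are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple


class ToyImageDataset:
    """Hypothetical stand-in for a fastai-style dataset: opens one item by index."""

    def __init__(self, fnames: Sequence[str], labels: Sequence[int],
                 open_fn: Callable[[str], List[float]]):
        assert len(fnames) == len(labels), "one label per file"
        self.fnames = list(fnames)
        self.labels = list(labels)
        self.open_fn = open_fn  # stands in for real image loading (hypothetical)

    def get_x(self, i: int) -> List[float]:
        # Load the input for item i (here: whatever open_fn returns)
        return self.open_fn(self.fnames[i])

    def get_y(self, i: int) -> int:
        # Return the target for item i
        return self.labels[i]

    def __len__(self) -> int:
        return len(self.fnames)

    def __getitem__(self, i: int) -> Tuple[List[float], int]:
        return self.get_x(i), self.get_y(i)


# Usage: a fake "loader" that turns a filename into a list of pixel values
ds = ToyImageDataset(["a.png", "bb.png"], [0, 1],
                     open_fn=lambda f: [float(len(f))] * 3)
print(len(ds), ds[1])  # 2 ([6.0, 6.0, 6.0], 1)
```

Printing `ds[i]` like this is the "make sure it returns what you want" check from the steps below.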
DataLoader
- Base fastai class: `DataLoader`
- Knows how to read data from a `Dataset` and create batches from it
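The core job of a DataLoader can be sketched in a few lines of plain Python: read items from the dataset by index and group them into batches. The real fastai and PyTorch loaders also handle sampling/shuffling, worker processes and collating into tensors; `iterate_batches` is a hypothetical name, and "collating" here just means splitting items into a list of inputs and a list of targets.

```python
from typing import Any, Iterator, List, Sequence, Tuple


def iterate_batches(dataset: Sequence[Tuple[Any, Any]], batch_size: int
                    ) -> Iterator[Tuple[List[Any], List[Any]]]:
    """Yield (xs, ys) batches by reading dataset items in index order."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        xs = [x for x, _ in items]   # inputs of this batch
        ys = [y for _, y in items]   # targets of this batch
        yield xs, ys


data = [([i, i], i % 2) for i in range(5)]
for xs, ys in iterate_batches(data, batch_size=2):
    print(xs, ys)
# [[0, 0], [1, 1]] [0, 1]
# [[2, 2], [3, 3]] [0, 1]
# [[4, 4]] [0]
```

Note the last batch is smaller when the dataset size is not a multiple of the batch size.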
ModelData
- Base fastai class: `ModelData`
- Contains your DataLoaders, the path to your data, transformations, etc.
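ModelData is essentially a container, so a dataclass is enough to sketch the idea. `ToyModelData` and its field names are hypothetical and only loosely follow the description above; check fastai’s `ModelData` for its real constructor.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass
class ToyModelData:
    """Hypothetical container bundling everything a training run needs to find its data."""
    path: str                      # where the data (and saved models) live
    trn_dl: Any                    # training DataLoader
    val_dl: Any                    # validation DataLoader
    test_dl: Optional[Any] = None  # optional test DataLoader
    tfms: List[Callable] = field(default_factory=list)  # input transformations


md = ToyModelData(path="data/", trn_dl=[[1, 2]], val_dl=[[3]])
print(md.path, len(md.trn_dl), md.test_dl)  # data/ 1 None
```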
Model
- Base class: `nn.Module`
- A pure PyTorch model of whatever architecture you want
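In PyTorch you subclass `nn.Module` and implement `forward`; calling the module (`model(x)`) runs `forward` for you. The plain-Python sketch below only mimics that convention so it runs without PyTorch; `TinyLinearModel` is a hypothetical name and the model is just a linear function.

```python
from typing import List


class TinyLinearModel:
    """Plain-Python stand-in for an nn.Module: y = w . x + b for each input row."""

    def __init__(self, n_in: int):
        self.w = [0.0] * n_in  # one weight per input feature
        self.b = 0.0           # bias

    def forward(self, xb: List[List[float]]) -> List[float]:
        # One prediction per row in the batch
        return [sum(wi * xi for wi, xi in zip(self.w, x)) + self.b for x in xb]

    def __call__(self, xb: List[List[float]]) -> List[float]:
        # Mirrors nn.Module, where model(x) dispatches to forward(x)
        return self.forward(xb)


model = TinyLinearModel(n_in=2)
model.w, model.b = [1.0, 2.0], 0.5
print(model([[1.0, 1.0], [0.0, 3.0]]))  # [3.5, 6.5]
```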
Loss function
- The function that calculates the loss between the Model output and the target
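A loss function is just "model output and target in, one number out". Mean squared error is used below purely as an example (for classification you would more likely use cross-entropy); `mse_loss` is a hypothetical plain-Python version, not a library call.

```python
from typing import Sequence


def mse_loss(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error between model outputs and targets."""
    assert len(preds) == len(targets), "one target per prediction"
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


print(mse_loss([3.5, 6.5], [3.0, 7.0]))  # 0.25
```

Putting a breakpoint at the start of a function like this lets you inspect the shapes of both arguments before anything is computed.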
Optimizer
- Your favorite optimizer (Adam, SGD, RMSprop, etc.)
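An optimizer turns gradients into parameter updates and often keeps state between steps. The sketch below is SGD with momentum in plain Python, following the same update convention as PyTorch’s `torch.optim.SGD` with `momentum` set (no dampening, no Nesterov). `SGDMomentum` is a hypothetical class for illustration.

```python
from typing import List, Optional


class SGDMomentum:
    """Minimal SGD with momentum; keeps one velocity value per parameter between steps."""

    def __init__(self, lr: float, momentum: float = 0.9):
        self.lr = lr
        self.momentum = momentum
        self.velocity: Optional[List[float]] = None  # optimizer state, created lazily

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        if self.velocity is None:
            self.velocity = [0.0] * len(params)
        # v <- momentum * v + grad ; p <- p - lr * v
        self.velocity = [self.momentum * v + g for v, g in zip(self.velocity, grads)]
        return [p - self.lr * v for p, v in zip(params, self.velocity)]


opt = SGDMomentum(lr=0.1, momentum=0.9)
params = [1.0]
params = opt.step(params, [1.0])  # first step behaves like plain SGD
params = opt.step(params, [1.0])  # second step is larger thanks to the stored velocity
```

The stored velocity is why an optimizer is an object rather than a bare function.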
Stepper
- Base fastai class: `Stepper`
- Makes the optimizer steps during training
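One training step is roughly: forward pass, loss, gradients, optimizer update. In PyTorch the gradient part is `loss.backward()` followed by `optimizer.step()` (with `optimizer.zero_grad()` clearing old gradients). The plain-Python sketch below derives the gradients by hand for a one-feature linear model under MSE; `train_step` is a hypothetical name and the toy data is made up.

```python
from typing import List, Tuple


def train_step(w: float, b: float, xs: List[float], ys: List[float], lr: float
               ) -> Tuple[float, float, float]:
    """One step for y_hat = w*x + b under MSE: forward, loss, gradients, SGD update.

    Returns the updated (w, b) and the loss computed before the update.
    """
    n = len(xs)
    preds = [w * x + b for x in xs]                         # forward pass
    errs = [p - y for p, y in zip(preds, ys)]
    loss = sum(e * e for e in errs) / n                     # MSE loss
    grad_w = 2 * sum(e * x for e, x in zip(errs, xs)) / n   # dL/dw, hand-derived "backward"
    grad_b = 2 * sum(errs) / n                              # dL/db
    return w - lr * grad_w, b - lr * grad_b, loss           # plain SGD update


w, b = 0.0, 0.0
for _ in range(200):
    w, b, loss = train_step(w, b, xs=[0.0, 1.0, 2.0], ys=[1.0, 3.0, 5.0], lr=0.1)
print(round(w, 3), round(b, 3))  # should approach 2.0 and 1.0 (data follows y = 2x + 1)
```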
Sampler
- Base class: torch `Sampler`
- Samples your data in some way during training (e.g. balancing classes)
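Class balancing, the example above, can be done by drawing indices with replacement, weighting each item by one over the size of its class. PyTorch ships `torch.utils.data.WeightedRandomSampler` for this; the plain-Python sketch below shows the same idea, and `balanced_indices` is a hypothetical name.

```python
import random
from collections import Counter
from typing import List, Sequence


def balanced_indices(labels: Sequence[int], n_samples: int, seed: int = 0) -> List[int]:
    """Draw dataset indices with replacement, weighting each item by 1 / (its class count),
    so every class is drawn roughly equally often regardless of imbalance."""
    counts = Counter(labels)
    weights = [1.0 / counts[y] for y in labels]
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=n_samples)


labels = [0] * 90 + [1] * 10            # 90/10 class imbalance
idxs = balanced_indices(labels, n_samples=10_000)
drawn = Counter(labels[i] for i in idxs)
print(drawn[0] / len(idxs), drawn[1] / len(idxs))  # both close to 0.5
```

The DataLoader then reads the dataset in the order given by these indices instead of 0, 1, 2, ...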
Learner
- Base fastai class: `Learner`
- Knows how to train your Model with the given ModelData, Loss function, Optimizer, Stepper and Sampler
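The Learner ties all the pieces together in an epoch loop. The self-contained sketch below puts the roles from the list above into one hypothetical `ToyLearner` class (shuffled batching as the DataLoader/Sampler, a linear model, MSE loss and plain SGD as the Stepper). It is not fastai’s `Learner` API, only an illustration of what `fit` has to orchestrate.

```python
import random
from typing import List, Sequence, Tuple

Batch = Tuple[List[float], List[float]]


class ToyLearner:
    """Hypothetical Learner stand-in: owns data, model parameters and lr, and runs epochs."""

    def __init__(self, xs: Sequence[float], ys: Sequence[float], lr: float, bs: int = 2):
        self.data = list(zip(xs, ys))
        self.lr, self.bs = lr, bs
        self.w, self.b = 0.0, 0.0  # "model" parameters of y_hat = w*x + b

    def batches(self, rng: random.Random) -> List[Batch]:
        # DataLoader + Sampler role: shuffle indices, then cut them into batches
        idx = list(range(len(self.data)))
        rng.shuffle(idx)
        out = []
        for s in range(0, len(idx), self.bs):
            chunk = [self.data[i] for i in idx[s:s + self.bs]]
            out.append(([x for x, _ in chunk], [y for _, y in chunk]))
        return out

    def step(self, xb: List[float], yb: List[float]) -> float:
        # Stepper role: forward, MSE loss, hand-derived gradients, SGD update
        n = len(xb)
        errs = [self.w * x + self.b - y for x, y in zip(xb, yb)]
        self.w -= self.lr * 2 * sum(e * x for e, x in zip(errs, xb)) / n
        self.b -= self.lr * 2 * sum(errs) / n
        return sum(e * e for e in errs) / n

    def fit(self, epochs: int, seed: int = 0) -> List[float]:
        rng = random.Random(seed)
        history = []
        for _ in range(epochs):
            losses = [self.step(xb, yb) for xb, yb in self.batches(rng)]
            history.append(sum(losses) / len(losses))  # mean training loss per epoch
        return history


learn = ToyLearner(xs=[0.0, 1.0, 2.0, 3.0], ys=[1.0, 3.0, 5.0, 7.0], lr=0.05)
hist = learn.fit(epochs=300)
print(round(learn.w, 2), round(learn.b, 2))  # should be close to 2.0 and 1.0
```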
Simple steps to create your custom pipeline
NOTE: You may not need to create a new class at all; an existing class may already work for you, so check that first.
- Create a `Dataset` and make sure it returns what you want
- Create a `DataLoader` for your dataset
- Create a `ModelData`
- Create a `Model` and try it on some sample data from your `DataLoader`
- Grab a suitable `Optimizer` and loss function and create a `Learner` with them (you can also create your own custom `Stepper` and `Sampler` for this)
- Try calling `learner.fit(...)`
- Get some errors and get mad
- Chill out and put `ipdb.set_trace()` everywhere you can (in my experience the most useful places are the beginning of your model’s `forward()` function and the beginning of the loss function)
- Debug until it works
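The debugging steps above can be wired in so they stay in the code but only trigger when you want them. The sketch below uses Python’s built-in `breakpoint()` (3.7+), which opens `pdb` by default or `ipdb` if you set the environment variable `PYTHONBREAKPOINT=ipdb.set_trace`. The `PIPELINE_DEBUG` switch and `TinyModel` are hypothetical, and the shape assertion is an extra cheap check that fails early instead of deep inside the loss.

```python
import os
from typing import List

# Hypothetical switch: run with PIPELINE_DEBUG=1 to stop at the start of forward()
DEBUG = os.environ.get("PIPELINE_DEBUG") == "1"


class TinyModel:
    def __init__(self, n_in: int):
        self.w = [0.1] * n_in

    def forward(self, xb: List[List[float]]) -> List[float]:
        if DEBUG:
            breakpoint()  # pdb by default, ipdb with PYTHONBREAKPOINT=ipdb.set_trace
        # Fail loudly on a shape mismatch before any math happens
        assert all(len(x) == len(self.w) for x in xb), \
            f"expected {len(self.w)} features per row"
        return [sum(w * v for w, v in zip(self.w, x)) for x in xb]


# Sanity check on one sample batch before calling fit
model = TinyModel(n_in=3)
batch = [[1.0, 2.0, 3.0], [0.0, 0.0, 1.0]]
out = model.forward(batch)
assert len(out) == len(batch), "one output per input row"
```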
I hope this little guide helps some of you. If you have something useful to add, message me or post it in this topic, and I will add it to this post.