How to set up fastai for Few-Shot/Meta-Learning?

I am trying to implement a meta-learning paper (Meta-SGD) with fastai, using the Omniglot dataset.

The meta model is defined by the model's initial weights plus a learning rate for every weight. Because the meta model tunes each parameter's learning rate independently, the parameter updates generally do not point along the negative gradient as they do in plain SGD. I need a way, in the on_batch_end callback, to apply a custom learning rate to each parameter. Is this possible in fastai/PyTorch? I am guessing I have to write a custom optimizer.
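A minimal PyTorch sketch of the Meta-SGD inner update, assuming a plain functional model instead of fastai's Learner. The names forward and meta_sgd_step, the toy linear classifier, and the random tensors standing in for Omniglot data are all hypothetical. The key point is that each alpha tensor has the same shape as its weight, so the step alpha * grad rescales every gradient component separately.

```python
import torch
import torch.nn.functional as F


def forward(params, x):
    """Hypothetical tiny functional linear classifier; params = [weight, bias]."""
    weight, bias = params
    return F.linear(x, weight, bias)


def meta_sgd_step(params, alphas, x, y):
    """One Meta-SGD inner update: theta' = theta - alpha * grad, elementwise."""
    loss = F.cross_entropy(forward(params, x), y)
    # create_graph=True keeps this step differentiable, so the meta loss on the
    # adapted weights can backpropagate into both params and alphas.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    return [p - a * g for p, a, g in zip(params, alphas, grads)]


torch.manual_seed(0)
n_way, k_shot, n_query, n_feat = 5, 1, 3, 16

# Meta-parameters: initial weights (theta) and one learning rate per weight (alpha).
params = [torch.randn(n_way, n_feat, requires_grad=True),
          torch.zeros(n_way, requires_grad=True)]
alphas = [torch.full(p.shape, 0.01, requires_grad=True) for p in params]
meta_opt = torch.optim.Adam(params + alphas, lr=1e-3)

# Random stand-in features for one task (placeholders, not real Omniglot data).
x_tr, y_tr = torch.randn(n_way * k_shot, n_feat), torch.arange(n_way).repeat(k_shot)
x_val, y_val = torch.randn(n_way * n_query, n_feat), torch.arange(n_way).repeat(n_query)

adapted = meta_sgd_step(params, alphas, x_tr, y_tr)          # inner step on the task's train set
meta_loss = F.cross_entropy(forward(adapted, x_val), y_val)  # meta loss on the task's val set
meta_opt.zero_grad()
meta_loss.backward()  # gradients reach both the initial weights and the alphas
meta_opt.step()
```

Design note: the inner step is written as an ordinary differentiable expression rather than as an optimizer because standard torch.optim optimizers, by default, update parameters in place without recording the step in the autograd graph, which would cut off the meta-gradient with respect to alpha. Only the outer (meta) update uses a regular optimizer here.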

Secondly, I need a sampler that picks N classes at a time (one task). For each task, I need a DataBunch whose train set has K images of each class and whose test (validation) set has the remaining 20 - K images (Omniglot has 20 images per character). How do I go about this? Is it possible to write a meta DataBunch whose train_ds is a list of DataBunches? Or would it make more sense to split the classes beforehand and create a list of DataBunches, using a different one every iteration?
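A minimal, framework-agnostic sketch of an N-way K-shot episode sampler, assuming a flat list with one class label per image. sample_episode is a hypothetical helper, not a fastai API. It returns (index, episode_label) pairs and relabels the sampled classes to 0..N-1, so each task looks like an ordinary N-class problem; the support indices act as the task's train set and the query indices as its validation set.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, n_query=None, rng=random):
    """Sample one N-way K-shot task from a flat list of per-image labels.

    Returns (support, query): lists of (dataset_index, episode_label) pairs.
    If n_query is None, every remaining image of a sampled class goes to query
    (e.g. 20 - K per class for Omniglot).
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    classes = rng.sample(sorted(by_class), n_way)  # the N classes of this task
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        idxs = rng.sample(by_class[cls], len(by_class[cls]))  # shuffled copy
        rest = idxs[k_shot:] if n_query is None else idxs[k_shot:k_shot + n_query]
        support += [(i, episode_label) for i in idxs[:k_shot]]
        query += [(i, episode_label) for i in rest]
    return support, query


# Toy label list: 30 classes x 20 images each, mimicking Omniglot's 20 drawings per character.
labels = [c for c in range(30) for _ in range(20)]
support, query = sample_episode(labels, n_way=5, k_shot=1, rng=random.Random(0))
print(len(support), len(query))  # 5 95
```

The indices could then be wrapped with torch.utils.data.Subset, remapping targets to the episode labels, to build each task's train and validation sets. Drawing a fresh episode every iteration, rather than precomputing a fixed list of DataBunches, is one way to keep the number of distinct tasks effectively unlimited.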

Thanks
