Hi, I want to convert my PyTorch code to fastai. The key points are the DataBunch and the loss_func.
Data:
My PyTorch dataloader returns three images (tensors or NumPy arrays), and 'disparity' is the target.
Model: My model takes 'left' and 'right' as inputs and returns multiple images (tensors) that are used to calculate the loss.
Note: unlike most examples, my model is multi-input and multi-output, and my target is defined in the input. So how do I build a DataBunch that contains multiple input values, and how do I get the target into loss_func?
The model and loss_func() look something like this:
```
import torch.nn as nn

class Pytorch_IresNet(nn.Module):
    def forward(self, left, right):
        ...  # network body omitted in the original post
        return predict_final, r_res2_predict, r_res1_predict, r_res0

loss = model_loss(predict_final, disp_predict, r_res2_predict, r_res1_predict, r_res0, target)
```
Can anyone suggest a convenient way to do this?
For the loss function, look at the fastai source code to see how it handles the predictions and targets. Your loss function's first parameter is the model output, and the remaining parameters are the individual targets your dataloader yields. With multiple targets you will have something like:
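```
def loss_func(preds, *targets):
    # preds: the model output (a tuple if forward() returns several tensors)
    # targets: target_1, target_2, target_3, ... in the order the dataloader yields them
    loss = ...  # calculate the loss from preds and targets
    return loss
```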
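For the model in the question, `preds` would be the whole tuple returned by `forward()`, and the only target is the disparity batch. A minimal sketch of such a loss, using a hypothetical `multiscale_disparity_loss` that stands in for the asker's `model_loss`: it assumes every output is a disparity estimate of shape `(N, 1, h, w)`, resizes each one to the target size with bilinear interpolation, and sums weighted smooth L1 losses. The weights are illustrative, not taken from the thread.

```python
import torch
import torch.nn.functional as F


def multiscale_disparity_loss(preds, target, weights=(1.0, 0.5, 0.5, 0.5)):
    """Hypothetical multi-output loss.

    preds:  tuple of model outputs, each assumed to be (N, 1, h, w)
    target: disparity batch of shape (N, 1, H, W)
    """
    total = target.new_zeros(())  # 0-dim tensor on the target's device/dtype
    # zip stops at the shorter sequence, so extra outputs without a weight are ignored
    for pred, w in zip(preds, weights):
        # Resize each output to the target's spatial size before comparing.
        if pred.shape[-2:] != target.shape[-2:]:
            pred = F.interpolate(pred, size=target.shape[-2:], mode="bilinear",
                                 align_corners=False)
        total = total + w * F.smooth_l1_loss(pred, target)
    return total
```

In fastai v1 a function like this would be passed as `Learner(data, model, loss_func=multiscale_disparity_loss)`; because fastai calls `loss_func(out, *yb)`, the tuple of outputs arrives as `preds` and the disparity batch as `target`.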
You also need to first create a DataLoader from the dataset using PyTorch's default DataLoader. To create the DataBunch, just use:
Thank you for your reply, it solved my problem perfectly. I had solved it a few hours earlier with a method very similar to yours, except that I used:
```
from fastai.basic_data import DataBunch  # fastai v1

# the code I used
data = DataBunch.create(train_dataset, test_dataset, bs=bs)
```
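A minimal sketch of what such a dataset could look like, assuming fastai v1, whose `loss_batch` (as far as I know) calls `model(*xb)` when `xb` is a list or tuple and then `loss_func(out, *yb)`. If each item is `((left, right), disparity)`, PyTorch's default collate turns a batch into `[left_batch, right_batch]` plus `disparity_batch`, so the two inputs reach `forward(left, right)` and the disparity reaches the loss as the target. `StereoDataset` and `ToyStereoNet` are hypothetical names filled with random data; the last lines reproduce that call pattern with a plain PyTorch `DataLoader`, so the sketch runs without fastai.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class StereoDataset(Dataset):
    """Hypothetical dataset: each item is ((left, right), disparity)."""

    def __init__(self, n=8, h=16, w=32):
        # Random tensors stand in for real stereo pairs and disparity maps.
        self.left = torch.rand(n, 3, h, w)
        self.right = torch.rand(n, 3, h, w)
        self.disparity = torch.rand(n, 1, h, w)

    def __len__(self):
        return len(self.left)

    def __getitem__(self, i):
        # Inputs grouped together, target returned separately.
        return (self.left[i], self.right[i]), self.disparity[i]


class ToyStereoNet(torch.nn.Module):
    """Hypothetical stand-in for a two-input, multi-output model."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(6, 1, kernel_size=3, padding=1)

    def forward(self, left, right):
        full = self.conv(torch.cat([left, right], dim=1))  # (N, 1, H, W)
        half = torch.nn.functional.avg_pool2d(full, 2)     # (N, 1, H/2, W/2)
        return full, half


dl = DataLoader(StereoDataset(), batch_size=4, shuffle=True)
xb, yb = next(iter(dl))   # default collate: xb = [left_batch, right_batch], yb = disparity
out = ToyStereoNet()(*xb)  # same unpacking as model(*xb)
print(len(xb), xb[0].shape, yb.shape, [o.shape for o in out])
```

With fastai v1 installed, the same dataset objects can be passed directly to `DataBunch.create(train_ds, valid_ds, bs=...)` as in the code above.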
My model runs well. To analyze it better, I want to log its outputs the way I did with tensorboardX (mainly the target tensors/images). Can I record my output tensors in fastai and convert them to images, like the logger.add_image() function in tensorboardX? If so, where can I find the relevant documentation?
You need to create a custom callback and pass it to the fit function; see the fastai docs on callbacks, it's easy. If you don't want to do that, a quick hack is to save the images manually from the output inside the loss function's forward method, but I would prefer a callback because it's the right way to do it.
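A minimal sketch of such a callback, assuming fastai v1, where callback methods receive the training state as keyword arguments including `last_output` (the model output), `last_target` and `train`. `ImageLoggerCallback` and `to_image` are hypothetical names. The writer is any object with a tensorboardX-style `add_image(tag, img, step)` method, such as `tensorboardX.SummaryWriter` or `torch.utils.tensorboard.SummaryWriter`. The import falls back to `object` so the sketch also runs without fastai v1; for real use it must subclass fastai's `Callback`.

```python
import torch

try:  # fastai v1 defines Callback in fastai/callback.py
    from fastai.callback import Callback
except ImportError:  # fallback so this sketch runs without fastai v1
    Callback = object


def to_image(t):
    """Min-max normalize a (1, H, W) or (H, W) tensor to [0, 1] in CHW layout."""
    t = t.detach().float().cpu()
    if t.dim() == 2:
        t = t.unsqueeze(0)
    lo, hi = t.min(), t.max()
    return (t - lo) / (hi - lo + 1e-8)


class ImageLoggerCallback(Callback):
    """Hypothetical callback: log the first prediction and target of a few
    validation batches per epoch to a writer with add_image()."""

    def __init__(self, writer, max_images=4):
        self.writer, self.max_images = writer, max_images
        self.logged, self.step = 0, 0

    def on_epoch_begin(self, **kwargs):
        self.logged = 0  # reset the per-epoch budget

    def on_batch_end(self, last_output, last_target, train, **kwargs):
        if train or self.logged >= self.max_images:
            return
        # For a multi-output model, assume the first output is the final prediction.
        pred = last_output[0] if isinstance(last_output, (tuple, list)) else last_output
        self.writer.add_image("pred", to_image(pred[0]), self.step)
        self.writer.add_image("target", to_image(last_target[0]), self.step)
        self.logged += 1
        self.step += 1
```

With fastai v1 it would be used as `learn.fit(epochs, lr, callbacks=[ImageLoggerCallback(SummaryWriter())])`. Passing the writer in, rather than creating it inside the callback, keeps the callback independent of which TensorBoard package you use.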