Grabbing the Batch Size in a Loss Function?

Quick question: I'm using a loss function that depends on the batch_size used for the data. How do I pass it in? I know I could implement it as a callback, but I'm not ready to convert it over yet.


def total_loss(out, label):
    bs = ...  # how do I get the batch size in here?

Can I pass it into the loss function as an extra argument? If so, how do I know the model will actually pass it through if I write:

def total_loss(out, label, batch_size):
    ...

Thanks for any clarification you guys could add for me!

In general, the (mini-)batch size is the size of the first dimension of out, i.e. out.shape[0].
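A minimal sketch of reading the batch size from out instead of passing it in, assuming PyTorch tensors; the squared-error body is a hypothetical placeholder for your own loss:

```python
import torch

def total_loss(out, label):
    # The mini-batch size is the size of dimension 0 of the model output.
    bs = out.shape[0]
    # Hypothetical loss for illustration: summed squared error averaged over the batch.
    return ((out - label) ** 2).sum() / bs

out = torch.ones(4, 3)     # a batch of 4 predictions with 3 values each
label = torch.zeros(4, 3)
print(total_loss(out, label))  # tensor(3.)
```

Because bs is read from the tensor on every call, the signature stays (out, label) and nothing extra has to be passed in.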

To understand the scope of variables used inside functions, read up on “closures” in Python.
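For example, a closure lets the loss read a batch size defined in an enclosing function. This is a minimal sketch: make_total_loss is a hypothetical helper name, the squared-error body is a placeholder, and plain lists stand in for tensors so it runs without PyTorch.

```python
def make_total_loss(batch_size):
    # batch_size lives in this enclosing scope...
    def total_loss(out, label):
        # ...and the inner function can still read it after make_total_loss returns.
        return sum((o - t) ** 2 for o, t in zip(out, label)) / batch_size
    return total_loss

# The returned function takes only (out, label), but "remembers" batch_size=12.
loss_func = make_total_loss(12)
print(loss_func([1.0, 2.0], [0.0, 0.0]))
```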

HTH, Malcolm

By mini-batch size, do you mean the size I specify when I create my DataBunch, for instance bs=12?

mini-batch size == bs in fastai.

It’s easy to see whether bs ends up as out’s dimension 0. Just try it: run a batch through your model and check the dimensions of out.
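A quick way to run that check in plain PyTorch. The toy nn.Linear model and random data are hypothetical stand-ins for your own model and DataBunch (30 samples, batch_size=12):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins: 30 samples with 5 features and one regression target each.
x = torch.randn(30, 5)
y = torch.randn(30, 1)
dl = DataLoader(TensorDataset(x, y), batch_size=12)  # bs=12, as in the question
model = nn.Linear(5, 1)

shapes = []
with torch.no_grad():
    for xb, yb in dl:
        out = model(xb)
        shapes.append(out.shape[0])  # size of dimension 0 of the output
print(shapes)  # [12, 12, 6]
```

Dimension 0 of out matches bs, except for the last batch, which is smaller when the dataset size isn't a multiple of bs (PyTorch's DataLoader keeps that partial batch unless drop_last=True).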

If it doesn't, understanding partial functions (functools.partial) will help you.
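A minimal sketch of the partial-function approach using functools.partial from the standard library. The squared-error body is a hypothetical placeholder, and plain lists stand in for tensors so it runs on its own; the same body also works on 1-D tensors. In fastai, the resulting function is what you would pass as loss_func when creating the Learner.

```python
from functools import partial

def total_loss(out, label, batch_size):
    # Hypothetical loss: sum of squared errors divided by a caller-supplied batch size.
    return sum((o - t) ** 2 for o, t in zip(out, label)) / batch_size

# Bind batch_size now; the result only takes (out, label),
# which is the signature a training loop calls a loss function with.
loss_func = partial(total_loss, batch_size=12)
print(loss_func([1.0, 2.0], [0.0, 0.0]))
```

Note that a fixed batch_size will be off for a final batch that is smaller than bs; reading out.shape[0] inside the loss avoids that.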


Got it, that makes sense. I’ll look into partial functions. Thank you very much!