Hi, I am doing semantic segmentation with a large class imbalance (5 classes), so I am passing a weight array into my loss: loss_fn = torch.nn.CrossEntropyLoss(weight=loss_weights)
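For context, here is a minimal sketch of that setup for segmentation-shaped tensors. The inverse-frequency weighting is an assumption (one common heuristic for picking starting weights, not something stated in the thread), and the label masks are toy data. With the default reduction='mean', CrossEntropyLoss divides by the sum of the weights of the target pixels, so only the ratios between the weights matter.

```python
import torch
import torch.nn as nn

def inverse_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Weight each class by the inverse of its pixel frequency, rescaled to mean 1."""
    # Count how many pixels belong to each class across all label masks.
    counts = torch.bincount(labels.flatten(), minlength=num_classes).float()
    # Guard against classes that never appear so we never divide by zero.
    counts = counts.clamp(min=1.0)
    weights = counts.sum() / counts
    # Rescale so the weights average to 1; with reduction='mean' only ratios matter.
    return weights / weights.mean()

num_classes = 5
# Toy labels (hypothetical): 2 masks of 4x4 pixels, dominated by class 0.
labels = torch.zeros(2, 4, 4, dtype=torch.long)
labels[0, 0, :2] = 1
labels[1, 3, 3] = 2
labels[1, 2, :] = 3
labels[0, 3, 0] = 4

loss_weights = inverse_frequency_weights(labels, num_classes)
loss_fn = nn.CrossEntropyLoss(weight=loss_weights)

# Segmentation logits have shape (N, C, H, W); targets have shape (N, H, W).
logits = torch.randn(2, num_classes, 4, 4)
loss = loss_fn(logits, labels)
```

Rare classes (here 2 and 4, one pixel each) get the largest weights and the dominant class 0 gets the smallest.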

Since it is very hard to choose these weights by hand, I train for 200 epochs and then turn loss_weights into trainable parameters. But when I do this, I get the error:
RuntimeError: the derivative for ‘weight’ is not implemented

How can I get around this? And do you have any other suggestions for dealing with this kind of class imbalance?

You can always decrease the loss by decreasing the weights: if you set them to negative infinity, you would get a loss of negative infinity. So even if you could backprop into the weights, you're not going to learn anything meaningful without adding some kind of regularization. Even then, I suspect that all of the weight would get concentrated on the class with the lowest error rate, and you would end up with a degenerate classifier.
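A sketch of both points, under stated assumptions: the error comes from passing a tensor that requires grad as the weight argument, so one way around it is to compute the unreduced per-pixel loss and apply the weights yourself. The learnable class_weights tensor and the toy data are hypothetical. The weighted loss is summed here (no normalization by the weights), which is the setting where the argument above applies directly. The gradient for each weight is the summed cross-entropy of that class's pixels, which is never negative, so gradient descent only pushes every weight down. With a mean-style normalization by the sum of the weights, the optimizer would instead shift weight toward the class with the lowest loss, which is the degenerate case described above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 5
logits = torch.randn(2, num_classes, 4, 4)          # (N, C, H, W), toy data
target = torch.randint(0, num_classes, (2, 4, 4))   # (N, H, W), toy data

# Hypothetical learnable class weights, all starting at 1.
class_weights = torch.ones(num_classes, requires_grad=True)

# Per-pixel unweighted cross-entropy; no 'weight' argument is passed to the op.
per_pixel = F.cross_entropy(logits, target, reduction="none")  # shape (N, H, W)
# Look up each pixel's class weight and apply it manually so autograd reaches it.
loss = (class_weights[target] * per_pixel).sum()
loss.backward()

# d(loss)/d(w_c) = summed cross-entropy over pixels of class c (>= 0),
# so a plain gradient step lowers every weight and the loss with it.
print(class_weights.grad)
```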

An approach that may work is to treat the weights as a hyperparameter and optimize over them using e.g. random search, Hyperband, Bayesian optimization, etc. This is obviously very expensive, but it would be a neat result if you could show that this kind of search leads to a better-performing model.
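The simplest of those options, random search, can be sketched as below. The evaluate callback is a hypothetical stand-in for "train a model with these class weights and return a validation score such as mIoU"; in practice each call is a full training run, which is where the expense comes from. Sampling the weights log-uniformly between assumed bounds of 0.1 and 10 is a design choice that spreads trials across orders of magnitude instead of clustering them near the upper bound.

```python
import math
import random
from typing import Callable, List, Optional, Tuple

def random_search_class_weights(
    evaluate: Callable[[List[float]], float],
    num_classes: int,
    num_trials: int,
    low: float = 0.1,
    high: float = 10.0,
    seed: int = 0,
) -> Tuple[Optional[List[float]], float]:
    """Sample class-weight vectors log-uniformly and keep the best-scoring one."""
    rng = random.Random(seed)
    best_weights, best_score = None, -math.inf
    for _ in range(num_trials):
        # Log-uniform sample for each class weight within [low, high].
        weights = [math.exp(rng.uniform(math.log(low), math.log(high)))
                   for _ in range(num_classes)]
        score = evaluate(weights)  # hypothetical: train + validate, higher is better
        if score > best_score:
            best_weights, best_score = weights, score
    return best_weights, best_score

# Toy stand-in for a training run: the score peaks at an arbitrary made-up weight vector.
toy_optimum = [1.0, 5.0, 2.0, 8.0, 3.0]
def toy_evaluate(weights: List[float]) -> float:
    return -sum((math.log(w) - math.log(t)) ** 2 for w, t in zip(weights, toy_optimum))

best_w, best_score = random_search_class_weights(toy_evaluate, num_classes=5, num_trials=200)
```

Hyperband or Bayesian optimization would replace the sampling loop with something that spends fewer full training runs on clearly bad weight vectors.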