import torch
from torch import nn

# Per-class rescaling weights for the loss, indexed by class id:
# class 0 is scaled by 0.4 and class 1 by 1.
weights = [0.4, 1]

# Build the float32 weight tensor once. It is moved to the GPU only when CUDA
# is available, so the script no longer crashes on CPU-only machines.
# NOTE(review): the weight tensor must be on the same device as the model
# outputs fed to the loss -- confirm against where `learn` places its model.
class_weights = torch.tensor(weights, dtype=torch.float32)
if torch.cuda.is_available():
    class_weights = class_weights.cuda()

# `learn` is created outside this snippet; replace its criterion with a
# class-weighted cross-entropy loss.
learn.crit = nn.CrossEntropyLoss(weight=class_weights)