According to the PyTorch documentation of nn.BCEWithLogitsLoss, pos_weight is an optional argument that takes the weight of positive examples. I don't fully understand the statement "pos_weight > 1 increases recall and pos_weight < 1 increases precision" on that page. How do you guys understand this statement?
The binary cross-entropy with logits loss (nn.BCEWithLogitsLoss, equivalent to F.binary_cross_entropy_with_logits) is a sigmoid layer (nn.Sigmoid) followed by a binary cross-entropy loss (nn.BCELoss), fused into a single operation that is more numerically stable than applying the two separately. The general case assumes you are in a multi-label classification task, i.e. a single input can be labeled with multiple classes. One common sub-case is to have a single class: the binary classification task. Define q as your tensor of predicted logits (raw, pre-sigmoid scores) and p as the ground-truth targets in [0,1], corresponding to the true probabilities for each class.
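As a quick sanity check of that equivalence, here is a minimal sketch comparing the fused loss with an explicit sigmoid followed by nn.BCELoss. The names logits and targets and the random inputs are illustrative assumptions, not part of the original question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)                      # raw scores q: 4 samples, 3 labels
targets = torch.randint(0, 2, (4, 3)).float()   # ground truth p in {0, 1}

# Fused version: sigmoid + BCE in one (numerically stable) call
fused = nn.BCEWithLogitsLoss()(logits, targets)
# Two-step version: explicit sigmoid, then plain BCE on probabilities
two_step = nn.BCELoss()(torch.sigmoid(logits), targets)

print(torch.allclose(fused, two_step, atol=1e-6))  # True
```

For moderate logits like these the two agree; for very large logits the fused version is the safer choice.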
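The explicit formulation of the weighted binary cross-entropy (element-wise, before averaging) is:
```python
import torch
z = torch.sigmoid(q)  # logits -> predicted probabilities
loss = -(w_p * p * torch.log(z) + (1 - p) * torch.log(1 - z))  # element-wise weighted BCE
```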
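where w_p (the pos_weight argument) is the weight associated with the positive (true) label of each class. Under the default reduction="mean", these element-wise losses are then averaged. Read this post for more details on the weighting scheme used by BCELoss.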
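To check that the formula above matches what PyTorch computes, the sketch below evaluates it by hand and compares it with nn.BCEWithLogitsLoss(pos_weight=...). The per-class weights [3.0, 1.0, 0.5] are arbitrary illustrative values; the PyTorch documentation suggests setting pos_weight to the ratio of negative to positive examples of a class (e.g. 300 negatives and 100 positives give 300/100 = 3).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
q = torch.randn(8, 3)                    # logits: 8 samples, 3 classes
p = torch.randint(0, 2, (8, 3)).float()  # targets in {0, 1}
w_p = torch.tensor([3.0, 1.0, 0.5])      # one pos_weight per class (illustrative values)

# Manual weighted BCE, averaged to mimic reduction="mean"
z = torch.sigmoid(q)
manual = -(w_p * p * torch.log(z) + (1 - p) * torch.log(1 - z)).mean()

# Built-in loss; pos_weight broadcasts over the class dimension
builtin = nn.BCEWithLogitsLoss(pos_weight=w_p)(q, p)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```

Note that w_p only multiplies the positive-target term; the (1 - p) term for negatives is left unweighted.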
For a given class:
precision = TP / (TP + FP)
recall = TP / (TP + FN)
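These definitions can be computed directly from thresholded predictions. The helper precision_recall below is a hypothetical name used for illustration, and returning 0.0 when a denominator is zero is an assumed convention.

```python
import torch

def precision_recall(preds: torch.Tensor, targets: torch.Tensor) -> tuple[float, float]:
    """Precision and recall for binary (0/1) predictions and targets."""
    tp = ((preds == 1) & (targets == 1)).sum().item()  # true positives
    fp = ((preds == 1) & (targets == 0)).sum().item()  # false positives
    fn = ((preds == 0) & (targets == 1)).sum().item()  # false negatives
    # Assumed convention: define the metric as 0.0 when its denominator is zero
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return precision, recall

targets = torch.tensor([1, 1, 1, 0, 0, 0])
preds = torch.tensor([1, 1, 1, 1, 0, 0])  # "eager" classifier: one extra positive
print(precision_recall(preds, targets))  # (0.75, 1.0)
```

Predicting positive more often catches every true positive (recall 1.0) but lets in a false positive (precision 0.75), which is exactly the trade-off pos_weight shifts.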
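Then if w_p > 1, the loss on positive examples is scaled up relative to negatives, so misclassifying a positive (a false negative, FN) becomes more costly than misclassifying a negative (a false positive, FP). The model is pushed to classify examples as positive more often: this reduces FN, which increases recall, but it also tends to increase FP, which decreases precision. Similarly, if w_p < 1, we are decreasing the weight on the positive class, so the model becomes more reluctant to predict positive: this reduces FP (increasing precision) but tends to increase FN, which decreases recall.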
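To see the effect empirically, here is a small sketch that trains a one-feature logistic regression with different pos_weight values and measures precision and recall at a 0.5 threshold. The synthetic data (overlapping Gaussians centered at -1 and +1), the learning rate, the number of steps, and the helper name train_and_score are all illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Synthetic, overlapping 1-D data: 1000 negatives around -1, 1000 positives around +1
x = torch.cat([torch.randn(1000, 1) - 1.0, torch.randn(1000, 1) + 1.0])
y = torch.cat([torch.zeros(1000, 1), torch.ones(1000, 1)])

def train_and_score(pos_weight: float) -> tuple[float, float]:
    """Train logistic regression with the given pos_weight; return (precision, recall)."""
    model = nn.Linear(1, 1)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
    opt = torch.optim.SGD(model.parameters(), lr=0.5)
    for _ in range(500):  # full-batch gradient descent
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    with torch.no_grad():
        preds = (torch.sigmoid(model(x)) > 0.5).float()  # threshold at 0.5
    tp = (preds * y).sum().item()
    fp = (preds * (1 - y)).sum().item()
    fn = ((1 - preds) * y).sum().item()
    return tp / (tp + fp), tp / (tp + fn)

for w in (0.2, 1.0, 5.0):
    precision, recall = train_and_score(w)
    print(f"pos_weight={w}: precision={precision:.3f}, recall={recall:.3f}")
```

As pos_weight grows, recall should rise and precision should fall. One way to see why: for an input whose true probability of being positive is π, the weighted loss is minimized at z = w_p π / (w_p π + 1 - π), so a 0.5 threshold predicts positive whenever π > 1 / (1 + w_p). With w_p > 1 that bar drops below 0.5; with w_p < 1 it rises above it.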