An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, https://arxiv.org/abs/1708.02002, generalized to the multi-class case.
It is essentially an enhancement to cross-entropy loss and is useful for classification tasks when there is a large class imbalance. It has the effect of down-weighting easy examples.
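To make the "down-weighting" concrete: focal loss multiplies the per-example cross-entropy term by `(1 - p_t)^gamma`, where `p_t` is the predicted probability of the true class, so confidently-classified examples contribute little to the loss. A minimal sketch of this idea (not this repo's implementation, which additionally supports `alpha` weighting, `reduction`, and `ignore_index`):

```python
import torch
import torch.nn.functional as F

def focal_loss_sketch(logits, targets, gamma=2.0):
    # Per-example cross-entropy: -log(p_t)
    ce = F.cross_entropy(logits, targets, reduction='none')
    # Recover p_t, the predicted probability of the true class
    pt = torch.exp(-ce)
    # Down-weight easy examples (p_t close to 1) by the factor (1 - p_t)^gamma
    return ((1 - pt) ** gamma * ce).mean()

logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
print(focal_loss_sketch(logits, targets))
```

With `gamma=0` this reduces to ordinary mean cross-entropy; larger `gamma` suppresses easy examples more aggressively.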
`FocalLoss` is an `nn.Module` and behaves very much like `nn.CrossEntropyLoss()`, i.e. it supports the `reduction` and `ignore_index` params, and is able to work with 2D inputs of shape `(N, C)` as well as K-dimensional inputs of shape `(N, C, d1, d2, ..., dK)`.
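Since `FocalLoss` mirrors `nn.CrossEntropyLoss`, its input/target shape conventions can be illustrated with the latter (the shapes below are the standard PyTorch convention, not anything specific to this repo):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # FocalLoss accepts the same shapes

# 2D case: (N, C) logits, (N,) integer class targets
logits_2d = torch.randn(4, 3)
targets_2d = torch.randint(0, 3, (4,))
loss_2d = criterion(logits_2d, targets_2d)

# K-dimensional case, e.g. segmentation: (N, C, H, W) logits, (N, H, W) targets
logits_4d = torch.randn(4, 3, 8, 8)
targets_4d = torch.randint(0, 3, (4, 8, 8))
loss_4d = criterion(logits_4d, targets_4d)
```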
## Example usage
```python
focal_loss = FocalLoss(alpha, gamma)
...
inp, targets = batch
out = model(inp)
loss = focal_loss(out, targets)
```
This repo supports importing modules through `torch.hub`. `FocalLoss` can be easily imported into your code via, for example:
```python
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='FocalLoss',
    alpha=torch.tensor([.75, .25]),
    gamma=2,
    reduction='mean',
    force_reload=False
)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)
```
Or:
```python
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='focal_loss',
    alpha=[.75, .25],
    gamma=2,
    reduction='mean',
    device='cpu',
    dtype=torch.float32,
    force_reload=False
)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)
```