
Finetune DETR

The goal of this Google Colab notebook is to fine-tune Facebook's DETR (DEtection TRansformer).

[Figure] From left to right: results obtained with pre-trained DETR, and after fine-tuning on the balloon dataset.

Usage

  • Acquire a dataset, e.g. the balloon dataset,
  • Convert the dataset to the COCO format, e.g. with the sketch shown below,
  • Run finetune_detr.ipynb to fine-tune DETR on this dataset,
  • Alternatively, run finetune_detectron2.ipynb to rely on the detectron2 wrapper.
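For the balloon dataset, the annotations ship in VGG Image Annotator (VIA) format, so the conversion step amounts to rewriting them as a COCO JSON file. Below is a minimal sketch of such a conversion; the helper name via_to_coco, the single "balloon" category, and the choice of category id 0 are illustrative assumptions, not the notebook's exact code:

import json
import os

from PIL import Image

def via_to_coco(via_json_path, image_dir, out_path, category_name="balloon"):
    """Hypothetical helper: convert VIA-style annotations to COCO format."""
    with open(via_json_path) as f:
        via = json.load(f)

    images, annotations, ann_id = [], [], 0
    for img_id, record in enumerate(via.values()):
        file_name = record["filename"]
        width, height = Image.open(os.path.join(image_dir, file_name)).size
        images.append({"id": img_id, "file_name": file_name,
                       "width": width, "height": height})
        regions = record["regions"]
        if isinstance(regions, dict):  # VIA stores regions as a dict or a list
            regions = regions.values()
        for region in regions:
            xs = region["shape_attributes"]["all_points_x"]
            ys = region["shape_attributes"]["all_points_y"]
            x, y = min(xs), min(ys)
            w, h = max(xs) - x, max(ys) - y
            annotations.append({
                "id": ann_id, "image_id": img_id, "category_id": 0,
                "bbox": [x, y, w, h],  # COCO boxes are [x, y, width, height]
                "area": w * h, "iscrowd": 0,
                "segmentation": [[v for xy in zip(xs, ys) for v in xy]],
            })
            ann_id += 1

    coco = {"images": images, "annotations": annotations,
            "categories": [{"id": 0, "name": category_name}]}
    with open(out_path, "w") as f:
        json.dump(coco, f)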

NB: Fine-tuning is recommended if your dataset has fewer than 10k images. Otherwise, training from scratch would be an option.

Data

DETR will be fine-tuned on a tiny dataset: the balloon dataset. We refer to it as the custom dataset.

There are 61 images in the training set, and 13 images in the validation set.

We expect the directory structure to be the following:

path/to/coco/
├ annotations/  # JSON annotations
│  ├ custom_train.json
│  └ custom_val.json
├ train2017/    # training images
└ val2017/      # validation images
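To sanity-check this layout, the dataset can be loaded with torchvision's CocoDetection class, which DETR's own COCO data loading builds upon; the paths below are the ones assumed in this README:

from torchvision.datasets import CocoDetection

train_dataset = CocoDetection("path/to/coco/train2017",
                              "path/to/coco/annotations/custom_train.json")
print(len(train_dataset))          # 61 images for the balloon training set
image, targets = train_dataset[0]  # a PIL image and its list of COCO annotations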

NB: if you are confused about the number of classes, check this GitHub issue.
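In short, DETR's num_classes argument should be the maximum category id plus one, not the raw number of categories. A minimal sketch, assuming the annotation file above:

import json

with open("path/to/coco/annotations/custom_train.json") as f:
    categories = json.load(f)["categories"]
# Gaps in the id range still count, hence max(id) + 1 rather than len(categories).
num_classes = max(category["id"] for category in categories) + 1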

Metrics

Typical metrics to monitor, partially shown in this notebook, include:

  • the Average Precision (AP), which is the primary challenge metric for the COCO dataset,
  • losses (total loss, classification loss, L1 bounding box distance loss, GIoU loss),
  • errors (cardinality error, class error).
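For reference, the COCO AP can be computed offline with pycocotools; here, predictions.json is a hypothetical file of detections in the COCO results format:

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("path/to/coco/annotations/custom_val.json")
coco_dt = coco_gt.loadRes("predictions.json")  # hypothetical detections file

evaluator = COCOeval(coco_gt, coco_dt, iouType="bbox")
evaluator.evaluate()
evaluator.accumulate()
evaluator.summarize()  # the first line, AP@[.50:.95], is the primary COCO metric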

As mentioned in the paper, there are 3 components to the matching cost and to the total loss:

  • classification loss,
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
    """Classification loss (NLL)
    targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
    """
    [...]
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
    losses = {'loss_ce': loss_ce}
  • L1 bounding box distance loss and GIoU loss, both computed in loss_boxes:
def loss_boxes(self, outputs, targets, indices, num_boxes):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
       targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
       The target boxes are expected in format (center_x, center_y, w, h), normalized by the image
       size.
    """
    [...]
    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
    losses['loss_bbox'] = loss_bbox.sum() / num_boxes
    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
        box_ops.box_cxcywh_to_xyxy(src_boxes),
        box_ops.box_cxcywh_to_xyxy(target_boxes)))
    losses['loss_giou'] = loss_giou.sum() / num_boxes
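To make these terms concrete, here is a toy computation of all three loss components on a single matched box pair. It assumes it is run from the root of the DETR repository (box_ops lives in util/box_ops.py), and all shapes and values are illustrative:

import torch
import torch.nn.functional as F

from util import box_ops  # from the DETR repository

# One matched pair of boxes in (center_x, center_y, w, h) format, normalized.
src_boxes = torch.tensor([[0.50, 0.5, 0.4, 0.4]])
target_boxes = torch.tensor([[0.55, 0.5, 0.4, 0.4]])
num_boxes = src_boxes.shape[0]

# L1 bounding box distance loss.
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum() / num_boxes

# GIoU loss: one value per matched pair, using the same helpers as above.
loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(
    box_ops.box_cxcywh_to_xyxy(src_boxes),
    box_ops.box_cxcywh_to_xyxy(target_boxes)))).sum() / num_boxes

# Classification loss: 3 queries per image, one object class plus "no-object"
# (the last class), which DETR down-weights with eos_coef (0.1 by default).
src_logits = torch.randn(2, 3, 2)                      # [batch, queries, classes + 1]
target_classes = torch.tensor([[0, 1, 1], [1, 1, 1]])  # 1 is "no-object" here
empty_weight = torch.tensor([1.0, 0.1])
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)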

Moreover, two errors are logged for monitoring; they do not propagate gradients:

  • cardinality error,
def loss_cardinality(self, outputs, targets, indices, num_boxes):
    """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty
    boxes. This is not really a loss, it is intended for logging purposes only. It doesn't
    propagate gradients
    """
    [...]
    # Count the number of predictions that are NOT "no-object" (which is the last class)
    card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
    card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
    losses = {'cardinality_error': card_err}
  • class error, which is computed within loss_labels:
    # TODO this should probably be a separate loss, not hacked in this one here
    losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
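As a toy illustration of the cardinality error above (shapes are assumptions): with 3 queries per image and the "no-object" class as the last logit, the error is the absolute difference between the number of non-empty predictions and the number of ground-truth boxes:

import torch
import torch.nn.functional as F

pred_logits = torch.randn(2, 3, 3)  # [batch, queries, classes + 1]
tgt_lengths = torch.tensor([1, 2])  # number of ground-truth boxes per image
# Predictions whose argmax is not the last ("no-object") class count as non-empty.
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())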

where accuracy (defined in util/misc.py of the DETR repository) is:

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""

Results

You should obtain acceptable results after 10 epochs, which requires only a few minutes of fine-tuning.

Out of curiosity, I have over-fine-tuned the model for 300 epochs (close to 3 hours).

All of the validation results are shown in view_balloon_validation.ipynb.
