# Optimization as a Model for Few-shot Learning

PyTorch implementation of "Optimization as a Model for Few-shot Learning" (ICLR 2017, Oral).
## Data

```
data/
  miniImagenet/
    train/
      n01532829/
        n0153282900000005.jpg
        ...
      n01558993/
        ...
    val/
      n01855672/
        ...
    test/
      ...
main.py
...
```
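Each training/evaluation episode is sampled from these class folders. As a rough illustration only (this is not the repo's actual data loader; `sample_episode` and its defaults, including the 15 query images per class, are hypothetical):

```python
import os
import random

def sample_episode(split_dir, n_class=5, n_shot=5, n_eval=15):
    """Hypothetical helper: draw one n_class-way, n_shot-shot episode."""
    classes = random.sample(os.listdir(split_dir), n_class)
    support, query = [], []
    for label, cls in enumerate(classes):
        cls_dir = os.path.join(split_dir, cls)
        images = random.sample(os.listdir(cls_dir), n_shot + n_eval)
        support += [(os.path.join(cls_dir, f), label) for f in images[:n_shot]]
        query += [(os.path.join(cls_dir, f), label) for f in images[n_shot:]]
    return support, query

# e.g. support/query image paths and labels for one 5-shot, 5-class episode
support, query = sample_episode('data/miniImagenet/train')
```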
## Train

Edit `scripts/train_5s_5c.sh` and make sure `--data-root` is properly set. For 5-shot, 5-class training, run:

```bash
bash scripts/train_5s_5c.sh
```

Hyper-parameters follow those in the author's repo.

## Evaluate

For 5-shot, 5-class evaluation, run (remember to change the `--resume` and `--seed` arguments):

```bash
bash scripts/eval_5s_5c.sh
```
## Results

| seed | train episodes | val episodes | val acc mean (%) | val acc std (%) | test episodes | test acc mean (%) | test acc std (%) |
|---|---|---|---|---|---|---|---|
| 719 | 41000 | 100 | 59.08 | 9.9 | 100 | 56.59 | 8.4 |
| - | - | - | - | - | 250 | 57.85 | 8.6 |
| - | - | - | - | - | 600 | 57.76 | 8.6 |
| 53 | 44000 | 100 | 58.04 | 9.1 | 100 | 57.85 | 7.7 |
| - | - | - | - | - | 250 | 57.83 | 8.3 |
| - | - | - | - | - | 600 | 58.14 | 8.5 |
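The accuracy columns are the mean and standard deviation of the per-episode accuracies. A minimal sketch of that aggregation (the placeholder values and the 95% confidence interval are for illustration only):

```python
import numpy as np

# one accuracy (in %) per evaluation episode, e.g. 100 / 250 / 600 of them
episode_accuracies = np.array([57.3, 61.2, 52.0, 60.9])  # placeholder values

mean, std = episode_accuracies.mean(), episode_accuracies.std()
# 95% confidence interval of the mean, as commonly reported for miniImagenet
ci95 = 1.96 * std / np.sqrt(len(episode_accuracies))
print(f"acc: {mean:.2f} +/- {ci95:.2f} (std {std:.2f})")
```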
## Notes

- `learner_w_grad` functions as a regular model; its loss and gradients are the inputs to the meta-learner.
- `learner_wo_grad` constructs the graph for the meta-learner:
    - Its weights are replaced by `cI`, the cell state output by the meta-learner.
    - The `nn.Parameter`s in this model are cast to `torch.Tensor` to connect the graph to the meta-learner.
- `copy_flat_params`: we only need the parameter values and keep the original `grad_fn`.
- `transfer_params`: we want the values as well as the `grad_fn` (from `cI` to `learner_wo_grad`); see the sketch after these notes.
    - `.data.copy_()` vs. `clone()`: the latter retains all the properties of a tensor, including its `grad_fn`.
    - To carry over the batch statistics, `load_state_dict` is used (from `learner_w_grad` to `learner_wo_grad`).
- Casting `nn.Parameter`s to `torch.Tensor` is inspired from here.
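The distinction between the two copy routines is easiest to see in code. Below is a minimal, self-contained sketch of the idea, not the repo's exact implementation; the toy `Learner` and its layer sizes are invented, while `get_flat_params`, `copy_flat_params`, and `transfer_params` mirror the helpers named above:

```python
import torch
import torch.nn as nn

class Learner(nn.Module):
    """Toy stand-in for the repo's learner; the layers are invented."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))

    def forward(self, x):
        return self.net(x)

    def get_flat_params(self):
        # flatten all parameters into one vector (the layout used for cI)
        return torch.cat([p.view(-1) for p in self.parameters()])

    def copy_flat_params(self, cI):
        # value-only copy: writing through .data leaves each nn.Parameter
        # (and hence any existing graph / grad_fn) untouched
        offset = 0
        for p in self.parameters():
            n = p.numel()
            p.data.copy_(cI[offset:offset + n].view_as(p))
            offset += n

    def transfer_params(self, learner_w_grad, cI):
        # copy buffers such as batch-norm running statistics
        self.load_state_dict(learner_w_grad.state_dict())
        # replace each nn.Parameter with a clone() of the matching slice of
        # cI: clone() keeps the grad_fn, so the meta-learner's graph stays
        # connected through this model's forward pass
        offset = 0
        for module in self.modules():
            for name, p in list(module.named_parameters(recurse=False)):
                n = p.numel()
                value = cI[offset:offset + n].view_as(p).clone()
                del module._parameters[name]  # demote from nn.Parameter...
                setattr(module, name, value)  # ...to a plain torch.Tensor
                offset += n

# usage sketch: cI is the flat parameter vector held by the meta-learner
learner_w_grad, learner_wo_grad = Learner(), Learner()
cI = learner_w_grad.get_flat_params().clone().requires_grad_()
learner_w_grad.copy_flat_params(cI)                  # values only
learner_wo_grad.transfer_params(learner_w_grad, cI)  # values + grad_fn
learner_wo_grad(torch.randn(2, 4)).sum().backward()
print(cI.grad.shape)                                 # gradients reach cI
```

The key design choice: `copy_flat_params` keeps `learner_w_grad` an ordinary trainable module, while `transfer_params` deliberately strips `nn.Parameter` status so that `learner_wo_grad`'s output differentiates back through `cI` into the meta-learner.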