NFNets and Adaptive Gradient Clipping (AGC) for SGD, implemented in PyTorch. Find an explanation at tourdeml.github.io/blog/
Fixes to `replace_conv`; `BatchNorm2d` layers are now replaced with `Identity`.
Paper: https://arxiv.org/abs/2102.06171.pdf
Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Install with:

```shell
pip3 install git+https://github.com/vballoli/nfnets-pytorch
```
Use `WSConv2d` like any other `torch.nn.Conv2d`:
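```python
import torch
from torch import nn
from nfnets import WSConv2d

conv = nn.Conv2d(3, 6, 3)   # standard convolution: 3 -> 6 channels, 3x3 kernel
w_conv = WSConv2d(3, 6, 3)  # same arguments, weight-standardized convolution
```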
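For intuition, here is a minimal sketch of scaled weight standardization, the reparameterization the NFNets paper applies to convolution weights. Each output channel's weights are shifted to zero mean and divided by σ·√N, where N is the fan-in, then multiplied by a learnable gain. The class name `ScaledWSConv2d`, the `eps=1e-4` floor, and the gain initialized to 1 are assumptions for illustration. This is not the `nfnets.WSConv2d` source.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ScaledWSConv2d(nn.Conv2d):
    """Hypothetical Conv2d with scaled weight standardization (illustration only)."""

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        # Learnable per-output-channel gain, initialized to 1 (assumption).
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight  # shape: (out_channels, in_channels // groups, kH, kW)
        fan_in = w[0].numel()
        dims = (1, 2, 3)
        mean = w.mean(dim=dims, keepdim=True)
        var = ((w - mean) ** 2).mean(dim=dims, keepdim=True)
        # (W - mean) / sqrt(var * fan_in), with var * fan_in floored at eps.
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps))
        return self.gain * (w - mean) * scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)
```

The standardization is recomputed from `self.weight` on every forward pass. The optimizer therefore still updates the raw weights, and only the convolution sees the standardized version.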
Similarly, use `SGD_AGC` like `torch.optim.SGD`:
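```python
import torch
from torch import nn, optim
from nfnets import WSConv2d, SGD_AGC

conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Both optimizers take the same arguments (parameters, learning rate).
# Use a new name so the `optim` module is not shadowed.
optimizer = optim.SGD(conv.parameters(), 1e-3)
optimizer_agc = SGD_AGC(w_conv.parameters(), 1e-3)
```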
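`SGD_AGC` applies clipping inside the optimizer step. As a hedged sketch of the rule described in the paper, consider each unit i (an output row or channel) with weights W_i and gradient G_i. If ‖G_i‖ / max(‖W_i‖, ε) > λ, the gradient is rescaled to λ·max(‖W_i‖, ε)·G_i/‖G_i‖. The helper names `unitwise_norm` and `adaptive_grad_clip_` and the defaults λ = 0.01 and ε = 1e-3 are assumptions. They are not part of the `nfnets` API.

```python
import torch


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit (dim 0); a single norm for scalars and vectors."""
    if x.ndim <= 1:
        return x.pow(2).sum().sqrt()
    dims = tuple(range(1, x.ndim))
    return x.pow(2).sum(dim=dims, keepdim=True).sqrt()


@torch.no_grad()
def adaptive_grad_clip_(parameters, clipping: float = 0.01, eps: float = 1e-3) -> None:
    """Hypothetical helper: rescale gradients in place, unit by unit, per the AGC rule."""
    for p in parameters:
        if p.grad is None:
            continue
        # max(||W_i||, eps) keeps zero-initialized weights from having gradients clipped to zero.
        param_norm = unitwise_norm(p).clamp(min=eps)
        grad_norm = unitwise_norm(p.grad)
        max_norm = clipping * param_norm
        # Rescale only the units whose gradient-to-weight norm ratio exceeds `clipping`.
        scaled = p.grad * (max_norm / grad_norm.clamp(min=1e-6))
        p.grad.copy_(torch.where(grad_norm > max_norm, scaled, p.grad))
```

In a training loop, a helper like this would run after `loss.backward()` and before a plain `optimizer.step()`. The paper also reports better results when the final linear classifier is left unclipped. A caller could do this by passing every parameter except that layer's.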
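To convert an existing model, call `replace_conv` on it. The call swaps the model's convolution layers for `WSConv2d` in place:

```python
import torch
from torch import nn
from torchvision.models import resnet18
from nfnets import replace_conv

model = resnet18()
replace_conv(model)  # modifies `model` in place
```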
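The sketch below shows the kind of recursive module swap that `replace_conv` presumably performs. `swap_layers` is a hypothetical helper, not the library's implementation. It walks `named_children()` and replaces each `nn.Conv2d` with an instance of a caller-supplied class built from the same hyperparameters. It also replaces each `nn.BatchNorm2d` with `nn.Identity`. The new layers are freshly initialized and no weights are copied, which assumes training from scratch.

```python
import torch
from torch import nn


def swap_layers(module: nn.Module, conv_cls) -> nn.Module:
    """Hypothetical sketch: Conv2d -> conv_cls and BatchNorm2d -> Identity, recursively."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and not isinstance(child, conv_cls):
            # Rebuild the convolution with the same shape-defining hyperparameters.
            new_conv = conv_cls(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding, dilation=child.dilation,
                groups=child.groups, bias=child.bias is not None,
            )
            setattr(module, name, new_conv)
        elif isinstance(child, nn.BatchNorm2d):
            # Normalizer-free networks drop batch normalization entirely.
            setattr(module, name, nn.Identity())
        else:
            swap_layers(child, conv_cls)  # recurse into containers such as nn.Sequential
    return module
```

Replacing entries through `setattr` on the parent keeps the module's registration intact, so `model.parameters()` picks up the new layers automatically.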
Find the documentation on readthedocs.
To cite the original paper, use:
```bibtex
@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}
```