G Mlp Gpt

GPT, but made only out of MLPs


GPT - gMLP

This repository attempts to crack long-context autoregressive language modeling (GPT) using variations of gMLP. Specifically, it contains a variant that applies gMLP over local sliding windows. The goal is to train context lengths of 4096 tokens and beyond, efficiently and well, on a single GPU.
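The core of a gMLP block is the spatial gating unit: the channels are split in half, one half is mixed across sequence positions by a learned projection, and the result gates the other half. Below is a minimal NumPy sketch of a causal version applied to local windows. It assumes non-overlapping windows and a lower-triangular mask on the spatial weights; the function names are hypothetical, and the library's actual windowing (for example, whether a window can also see the previous one) may differ. Only the gating unit is shown, not the surrounding channel projections and activation.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize over the channel (last) dimension
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def local_causal_sgu(x, weight, bias, window):
    """Hypothetical sketch: causal spatial gating within non-overlapping windows.

    x: (seq_len, dim), dim even; seq_len assumed to be a multiple of window.
    weight: (window, window) spatial projection; bias: (window,) per position.
    """
    n, d = x.shape
    assert n % window == 0, "this sketch assumes seq_len % window == 0"
    u, v = x[:, : d // 2], x[:, d // 2 :]           # split channels in half
    v = layer_norm(v)
    mask = np.tril(np.ones((window, window)))       # position i only mixes positions j <= i
    w = weight * mask
    # group tokens into windows so each window is mixed independently
    v = v.reshape(n // window, window, d // 2)
    v = np.einsum("ij,bjd->bid", w, v) + bias[None, :, None]
    v = v.reshape(n, d // 2)
    return u * v                                    # gate u with the spatially mixed v
```

Because the mask is lower-triangular and windows are processed separately, changing a token can only affect outputs at the same or later positions within its own window.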

You can also add the "tiny" attention described in the paper (the aMLP variant) with the attn_dim keyword argument, which sets the dimension of its single attention head (64 is recommended). Pass a tuple to use a different dimension for each layer.
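For intuition, here is a minimal NumPy sketch of single-head causal attention of the kind attn_dim controls. The causal mask is an assumption suited to autoregressive modeling, and the weight names are hypothetical; in the paper the attention output is combined with the gating branch, which this sketch leaves out.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def tiny_causal_attention(x, w_qkv, w_out, attn_dim=64):
    """Hypothetical sketch of single-head causal attention.

    x: (seq_len, dim); w_qkv: (dim, 3 * attn_dim); w_out: (attn_dim, dim_out)
    """
    n = x.shape[0]
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)          # each (seq_len, attn_dim)
    scores = (q @ k.T) / np.sqrt(attn_dim)              # (seq_len, seq_len)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)          # block attention to later tokens
    return softmax(scores) @ v @ w_out                  # project back to dim_out
```

With a single small head, the extra cost is a (seq_len × seq_len) score matrix of one head, which is why the paper calls it "tiny".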

Install

$ pip install g-mlp-gpt

Usage

import torch
from g_mlp_gpt import gMLPGPT

model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    depth = 4,
    seq_len = 1024,
    window = (128, 256, 512, 1024) # one window size per layer (len must equal depth)
)

x = torch.randint(0, 20000, (1, 1000))
logits = model(x) # (1, 1000, 20000)
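The window argument takes one entry per layer, and, as in the 16k example, an entry can also be a (window size, axial) pair. The hypothetical helper below is not part of the library; it only sketches how such a spec can be normalized and sanity-checked. Treating a bare int as axial = 1 and rejecting windows larger than seq_len are assumptions made for illustration.

```python
def normalize_windows(window, depth, seq_len):
    """Hypothetical helper: turn a per-layer window spec into (size, axial) pairs.

    Each entry is an int window size or a (size, axial) tuple. A bare int is
    treated as axial = 1, which is an assumption of this sketch.
    """
    if len(window) != depth:
        raise ValueError(f"expected {depth} window entries, got {len(window)}")
    layers = []
    for entry in window:
        size, axial = entry if isinstance(entry, tuple) else (entry, 1)
        if size > seq_len:  # assumed constraint, for illustration only
            raise ValueError(f"window {size} exceeds seq_len {seq_len}")
        layers.append((size, axial))
    return layers
```

Applied to the example above, normalize_windows((128, 256, 512, 1024), 4, 1024) returns [(128, 1), (256, 1), (512, 1), (1024, 1)].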

16k context length

import torch
from torch import nn  # needed for nn.Tanh() below
from g_mlp_gpt import gMLPGPT

model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    seq_len = 16384,
    reversible = True,    # reversible networks
    act = nn.Tanh(),      # tanh activation for spatial gating
    depth = 12,
    window = (
        128,
        128,
        256,
        512,
        1024,
        1024,
        (2048, 2),    # window size of 2048, axial of 2
        (2048, 2),
        (4096, 4),
        (4096, 4),
        (8192, 8),    # window size of 8192, axial of 8
        (8192, 8)
    )
).cuda()

x = torch.randint(0, 20000, (1, 16384)).cuda()
logits = model(x) # (1, 16384, 20000)
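Setting reversible = True is what makes a 12-layer, 16384-token model practical on one GPU. The library's internals are not shown here, so the sketch below assumes the standard two-stream reversible residual coupling (as in RevNet and Reformer): each layer's inputs can be recomputed exactly from its outputs, so intermediate activations need not be stored for backpropagation. f and g stand in for arbitrary sublayers.

```python
import numpy as np

def rev_forward(x1, x2, f, g):
    # two-stream reversible residual coupling (assumed RevNet/Reformer style)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_inverse(y1, y2, f, g):
    # recover the inputs exactly from the outputs instead of storing them
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2
```

The trade-off is extra compute: f and g are re-evaluated during the backward pass in exchange for activation memory that no longer grows with depth.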

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
README source: lucidrains/g-mlp-gpt
License
MIT
