# Deep implicit attention in PyTorch
Experimental implementation of deep implicit attention in PyTorch.
**Summary:** Using deep equilibrium models to implicitly solve a set of self-consistent mean-field equations of a random Ising model implements attention as a collective response 🤗 and provides insight into the transformer architecture, connecting it to mean-field theory, message-passing algorithms, and Boltzmann machines.
Blog post: Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms
Transformer architectures can be understood as particular approximations of a parametrized mean-field description of a vector Ising model being probed by incoming data x_i:

    z_i = sum_j J_ij z_j - f(z_i) + x_i

where f is a neural network acting on every vector z_i and the z_i are solved for iteratively.
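Such a fixed point can be found by simple damped iteration. Below is a minimal NumPy sketch with toy couplings and a `tanh` stand-in for the learned network f; the real modules in this repo use deep equilibrium solvers and trained networks instead:

```python
import numpy as np

def f(z):
    # Toy stand-in for the learned self-correction network.
    return np.tanh(z)

def solve_fixed_point(J, x, n_iter=200, damping=0.5):
    """Damped iteration of z_i = sum_j J_ij z_j - f(z_i) + x_i."""
    z = np.zeros_like(x)
    for _ in range(n_iter):
        z = (1 - damping) * z + damping * (J @ z - f(z) + x)
    return z

rng = np.random.default_rng(0)
N, d = 8, 4                              # number of spins/tokens, feature dim
J = 0.05 * rng.standard_normal((N, N))   # weak couplings so iteration converges
x = rng.standard_normal((N, d))          # input injection ("magnetic fields")
z = solve_fixed_point(J, x)
residual = np.abs(J @ z - f(z) + x - z).max()  # ~0 at the fixed point
```

The damping factor is a convergence aid for this toy example; deep equilibrium models typically use more sophisticated root-finding (e.g. Anderson acceleration or Broyden's method) instead.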
## DEQMLPMixerAttention
A deep equilibrium version of MLP-Mixer transformer attention (https://arxiv.org/abs/2105.02723, https://arxiv.org/abs/2105.01601):

    z_i = g({z_j}) - f(z_i) + x_i

where g is an MLP acting across the sequence dimension instead of the feature dimension (so across patches). The network f parametrizes the self-correction term and acts across the feature dimension (so individually on every sequence element).
Compared to a vanilla softmax attention transformer module (see below), the sum over couplings has been "amortized" and parametrized by an MLP. The fixed-point variables z_i are also fed straight into the feed-forward self-correction term. One could feed the naive mean-field update g({z_j}) + x_i instead to fully mimic the residual connection in the explicit MLP-Mixer architecture.
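As an illustration, here is a hedged NumPy sketch of this update rule, with a toy token-mixing MLP `g` (across patches) and a toy channel MLP `f` (per patch); the weight shapes, widths, and initialization scales are made up for the example:

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, h = 8, 4, 16   # patches (sequence length), feature dim, hidden width

# Token-mixing MLP g: mixes information across the sequence axis.
W1 = 0.05 * rng.standard_normal((h, N))
W2 = 0.05 * rng.standard_normal((N, h))
def g(z):            # z: (N, d)
    return W2 @ np.tanh(W1 @ z)

# Channel MLP f: the self-correction, acting on each patch's features.
V1 = 0.05 * rng.standard_normal((d, h))
V2 = 0.05 * rng.standard_normal((h, d))
def f(z):
    return np.tanh(z @ V1) @ V2

x = rng.standard_normal((N, d))
z = np.zeros_like(x)
for _ in range(100):                    # damped fixed-point iteration
    z = 0.5 * z + 0.5 * (g(z) - f(z) + x)
residual = np.abs(g(z) - f(z) + x - z).max()  # ~0 at the fixed point
```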
## DEQVanillaSoftmaxAttention
A deep equilibrium version of vanilla softmax transformer attention (https://arxiv.org/abs/1706.03762):
    z_i = sum_j J_ij z_j - f(z_i) + x_i

where

    J_ij = [softmax(X W_Q W_K^T X^T / sqrt(dim))]_ij

Transformer attention takes the couplings J_ij to depend parametrically on the x_i and treats the fixed-point equation above as a single-step update equation. Compared to the explicit vanilla softmax attention transformer module, there are no value projections, and the fixed-point variables z_i are fed straight into the feed-forward self-correction term.
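A NumPy sketch of this variant with toy dimensions; note that the toy self-correction `f` below includes a linear leak term, an assumption made here so that a fixed point exists for the row-stochastic couplings:

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
N, d = 6, 8
X = rng.standard_normal((N, d))
W_Q = rng.standard_normal((d, d)) / np.sqrt(d)
W_K = rng.standard_normal((d, d)) / np.sqrt(d)

# Data-dependent couplings: J_ij = [softmax(X W_Q W_K^T X^T / sqrt(dim))]_ij
J = softmax(X @ W_Q @ W_K.T @ X.T / np.sqrt(d))

def f(z):
    # Toy self-correction; the linear term keeps the damped map contractive.
    return 0.5 * z + 0.1 * np.tanh(z)

z = np.zeros_like(X)
for _ in range(200):                    # damped fixed-point iteration
    z = 0.5 * z + 0.5 * (J @ z - f(z) + X)
residual = np.abs(J @ z - f(z) + X - z).max()  # ~0 at the fixed point
```

Running the loop for a single step starting from z = 0 recovers something close to the usual one-shot attention update; iterating to convergence is what makes the model "implicit".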
## DEQMeanFieldAttention
Fast and neural deep implicit attention as introduced in https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/.
Schematically, the fixed-point mean-field equations including the Onsager self-correction term look like:

    z_i = sum_j J_ij z_j - f(z_i) + x_i

where f is a neural network parametrizing the self-correction term for every site and the x_i denote the input injections or magnetic fields applied at site i. Mean-field results are obtained by dropping the self-correction term. This model generalizes the current generation of transformers in the sense that its couplings are free parameters independent of the incoming data x_i.
## DEQAdaTAPMeanFieldAttention
Slow and explicit deep implicit attention as introduced in https://mcbal.github.io/post/deep-implicit-attention-a-mean-field-theory-perspective-on-attention-mechanisms/ (this module served as grounding and inspiration for the fast and neural one above).
Ising-like vector model with multivariate Gaussian prior over spins. Generalization of the application of the adaptive TAP mean-field approach from a system of binary/scalar spins to vector spins. Schematically, the fixed-point mean-field equations including the Onsager term look like:
    S_i ~ sum_j J_ij S_j - V_i S_i + x_i

where the V_i are self-corrections obtained self-consistently and the x_i denote the input injections or magnetic fields applied at site i. The linear-response correction step involves solving a system of equations, leading to a complexity ~ O(N^3 d^3). Mean-field results are obtained by setting V_i = 0.
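To see where the O(N^3 d^3) scaling comes from, note that the linearized self-consistency condition couples all N·d components at once. A schematic NumPy sketch with toy symmetric couplings and hypothetical fixed self-corrections V_i (the real adaptive TAP equations solve for the V_i self-consistently rather than fixing them up front):

```python
import numpy as np

rng = np.random.default_rng(3)
N, d = 6, 4

# Symmetric couplings with zero self-coupling, as in an Ising-like model.
J = rng.standard_normal((N, N))
J = 0.1 * (J + J.T)
np.fill_diagonal(J, 0.0)

# Hypothetical fixed self-corrections V_i (here: multiples of the identity).
V = [0.5 * np.eye(d) for _ in range(N)]

# Linearized fixed-point condition  S_i = sum_j J_ij S_j - V_i S_i + x_i,
# written as one big (N*d) x (N*d) linear system  A s = x.
A = np.eye(N * d) - np.kron(J, np.eye(d))
for i, V_i in enumerate(V):
    A[i * d:(i + 1) * d, i * d:(i + 1) * d] += V_i

x = rng.standard_normal(N * d)
s = np.linalg.solve(A, x)   # dense solve: O((N*d)^3) = O(N^3 * d^3)
```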
Given the couplings between spins and a prior distribution for the single-spin partition function, the adaptive TAP framework provides a closed-form solution in terms of sets of equations that should be solved for a fixed point. The algorithm is related to expectation propagation (see Section 4.3 in https://arxiv.org/abs/1409.6179) and boils down to matching the first and second moments assuming a Gaussian cavity distribution.
Install the package in editable mode:

```bash
$ pip install -e .
```

Run tests with:

```bash
$ python -m unittest
```
On variational inference, iterative approximation algorithms, expectation propagation, mean-field methods and belief propagation:
On the adaptive Thouless-Anderson-Palmer (TAP) mean-field approach in disorder physics:
On Boltzmann machines and mean-field theory:
On deep equilibrium models:
On approximate message passing (AMP) methods in statistics: