Implementation of the plug-and-play attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens".
ParallelWrapper:
- Added a `ParallelWrapper` class to simplify the usage of data parallelism.
- The `ParallelWrapper` class accepts a `use_data_parallel` flag to enable or disable data parallelism and applies `nn.DataParallel` to the wrapped model accordingly.
- The `DilatedAttention` model is wrapped with the `ParallelWrapper` class. The device on which the model should be loaded defaults to GPU (`cuda:0`) if CUDA is available; otherwise, it defaults to CPU.
- The key addition is the `ParallelWrapper` class, which facilitates easy and configurable usage of data parallelism with the provided `DilatedAttention` model. This ensures scalability across multiple GPUs without any significant change to the existing workflow: the user can now enable or disable data parallelism with a single flag.
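A minimal sketch of what such a wrapper can look like (the constructor arguments `model`, `device`, and `use_data_parallel` follow the description above; the repository's exact signature may differ):

```python
import torch
import torch.nn as nn

class ParallelWrapper(nn.Module):
    """Optionally wraps a model in nn.DataParallel and moves it to a device."""

    def __init__(self, model, device="cuda:0", use_data_parallel=True):
        super().__init__()
        # Fall back to CPU when CUDA is not available.
        self.device = device if torch.cuda.is_available() else "cpu"
        # Only apply DataParallel when requested and when more than one GPU exists.
        self.use_data_parallel = use_data_parallel and torch.cuda.device_count() > 1
        self.model = nn.DataParallel(model) if self.use_data_parallel else model
        self.model = self.model.to(self.device)

    def forward(self, x):
        return self.model(x)
```

Usage would then look like `model = ParallelWrapper(DilatedAttention(...), use_data_parallel=True)`, after which `model(x)` dispatches across the available GPUs transparently.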
Tensor Shape Adjustments:
- Ensured a consistent shape of tensors across all operations.
- Squeezed `a_indices` to 2D to match the dimensions of `att_denom_sums`: `a_indices = a_indices[:, :, 0].squeeze(-1).squeeze(-1)`
- Sliced `a_indices` to the unpadded sequence length before scattering: `a_indices = a_indices[:, :unpadded_seq_len]`
Scatter and Gather Operations:
- Scatter with the squeezed 2D `a_indices` and gather the sparse sums with the same indices: `att_denom_sums.scatter_add_(1, a_indices, a_denoms)` followed by `sparse_att_denom_sum = torch.gather(att_denom_sums, 1, a_indices)`
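For illustration, a self-contained sketch of this scatter/gather pattern; the names `att_denom_sums`, `a_indices`, `a_denoms`, and `unpadded_seq_len` mirror the snippets above, while the concrete sizes are made up:

```python
import torch

batch, seq_len, unpadded_seq_len = 2, 8, 4

# Running attention denominator sums, one slot per sequence position.
att_denom_sums = torch.zeros(batch, seq_len)

# Indices of the dilated (sparse) positions and their partial denominators.
a_indices = torch.randint(0, seq_len, (batch, seq_len), dtype=torch.long)
a_denoms = torch.rand(batch, unpadded_seq_len)

# Slice the indices to the unpadded sequence length before scattering.
a_indices = a_indices[:, :unpadded_seq_len]

# Accumulate the partial denominators into their original positions ...
att_denom_sums.scatter_add_(1, a_indices, a_denoms)

# ... then read the accumulated totals back out for the sparse positions.
sparse_att_denom_sum = torch.gather(att_denom_sums, 1, a_indices)
print(sparse_att_denom_sum.shape)  # torch.Size([2, 4])
```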
DataType Handling:
- Cast index tensors to `torch.int64` (or `torch.long`) to ensure compatibility with PyTorch's indexing operations.
- Used the `torch.float16` dtype for the `X` tensor to make it memory-efficient.
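A brief illustration of the two dtype choices (the tensor names and shapes here are placeholders, not the repository's exact code):

```python
import torch

# Index tensors used with gather/scatter must be int64 (torch.long).
a_indices = torch.tensor([[0, 2, 5, 7]], dtype=torch.long)

# Activations can be kept in half precision to roughly halve memory use
# (float16 matmuls are primarily intended for GPU execution).
X = torch.randn(1, 8, 512).to(torch.float16)
```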
Other Changes:
- Code cleaning
- Validation checks
- Enhanced error messages
- Optimizations
- Edge case handling (e.g., around `head_idx`)
- Other minor fixes
- Documentation updates
- Bug: Size mismatch in tensor operations in the forward method of the `DilatedAttentionLLAMA` class.
- Bug: Index out of range error while transposing tensors.
- Optimized Tensor Operations: The tensor operations in the forward method were optimized so that they all operate on tensors with matching dimensions, improving the efficiency of the model.
- Added Error Handling: Added checks for dimension mismatches in tensor operations so that useful error messages are raised when the input data does not match the expected shape.
- DilatedAttentionLLAMA Class: Introduced a new `DilatedAttentionLLAMA` class that uses the dilated attention mechanism in its forward method. This implementation is designed to be more efficient for longer sequence lengths.
- Performance Testing: Added a simple performance test to benchmark the speed of the forward method in the `DilatedAttentionLLAMA` class (see the sketch after this list).
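A minimal sketch of such a benchmark; the timing harness is generic, and the commented-out `DilatedAttentionLLAMA` constructor arguments are illustrative placeholders rather than the repository's exact signature:

```python
import time
import torch

def benchmark_forward(model, x, warmup=3, iters=10):
    """Returns the average forward-pass time of `model` on input `x`, in seconds."""
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):           # warm-up runs (lazy init, caches, kernel selection)
            model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()      # make sure queued GPU work has finished
        start = time.time()
        for _ in range(iters):
            model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.time() - start) / iters

# Example usage (constructor arguments are placeholders):
# model = DilatedAttentionLLAMA(...)
# x = torch.randn(1, 8192, 512)
# print(f"avg forward time: {benchmark_forward(model, x):.4f}s")
```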
Changelog

Bug Fixes:
- Issue: `ValueError: too many values to unpack (expected 3)`
  Root Cause: The attention function was returning more than three values, but the code was trying to unpack its return values into only three variables.
  Resolution: Modified the line where the attention function is called to collect all additional return values into a list using the `*` operator (see the sketch after this list).
- Issue: `RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 1`
  Root Cause: The code was trying to add two tensors of different sizes in the forward method of the `DynamicDilatedAttention` class.
  Resolution: Modified the line where the tensors are added to ensure that `attn_output` has the same size as the corresponding slice of `outputs` before adding them.
- Issue: `ValueError: not enough values to unpack (expected 7, got 6)`
  Root Cause: The `flash_attn` function in the `FlashAttention` class was trying to unpack the shape of the `q` tensor into seven variables, but the `q` tensor only had six dimensions.
  Resolution: Modified the forward method of the `DilatedAttention` class to reshape the `x` tensor correctly before passing it to the attention function.
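An illustration of the `*` unpacking fix for the first issue; the `attention` stand-in and its return values below are placeholders for the repository's actual attention call:

```python
# Dummy stand-in for the attention call; the real function returns extra
# values (e.g. cached state) beyond the three the caller originally expected.
def attention(q, k, v):
    return q, k, v, "extra_state", "extra_cache"

q, k, v = 1, 2, 3

# Before the fix this line raised "ValueError: too many values to unpack (expected 3)":
# out, weights, denom = attention(q, k, v)

# After the fix, any additional return values are collected into a list with `*`:
out, weights, denom, *extras = attention(q, k, v)
print(extras)  # ['extra_state', 'extra_cache']
```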
Improvements:
- Added assertions to check the types and values of the parameters in the `__init__` method of the `DilatedAttention` class to prevent incorrect usage.
- Added a check for the Distributed parameter in the `__init__` method of the `DilatedAttention` class to decide whether to use the `DataParallel` wrapper for the `FlashAttention` modules.
- Modified the forward method of the `DilatedAttention` class to process each segment of the input separately for each attention head, allowing the attention heads to share information between different segments.
- Modified the forward method of the `DilatedAttention` class to use a buffer to store the `attn_output_resized` tensor instead of creating a new tensor of zeros in every forward pass, improving efficiency (see the sketch after this list).
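A minimal sketch of the buffer-reuse idea behind the last improvement; the module structure, names, and shapes are illustrative assumptions rather than the repository's exact code, and autograd considerations are ignored for brevity:

```python
import torch
import torch.nn as nn

class BufferedAttentionOutput(nn.Module):
    """Reuses a preallocated buffer instead of allocating torch.zeros(...) each forward pass."""

    def __init__(self, batch_size, max_seq_len, dim):
        super().__init__()
        # register_buffer keeps the tensor on the module (it moves with .to()/.cuda())
        # without treating it as a learnable parameter.
        self.register_buffer("attn_output_resized",
                             torch.zeros(batch_size, max_seq_len, dim))

    def forward(self, attn_output):
        # attn_output: (batch, seq_len, dim) with seq_len <= max_seq_len.
        batch, seq_len, _ = attn_output.shape
        buf = self.attn_output_resized
        buf.zero_()                               # clear previous contents in place
        buf[:batch, :seq_len, :] = attn_output    # write the new output into the buffer
        return buf[:batch]
```

The design point is simply that the zeroed scratch tensor is allocated once and reused, rather than freshly allocated on every call; how the buffer is resized when an input outgrows it is a separate design choice.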