Model | Context length |
---|---|
LLAMA-2 | 4K |
Mistral-7B | 8K |
GPT-4, Gemini 1.0 | 32K |
Mosaic ML MPT | 65K |
Anthropic Claude | 100K |
Gemini 1.5 pro, GPT-4 Turbo | 128K |
Claude 2.1 | 200K |
Gemini 1.5 pro (limited) | 2 million (public), 10 million (experimental) |
(Figure) Attention variants covered: strided local attention, random local attention, sparse block attention, global+local attention, and Flash Attention.
(Figure) Local (sliding-window) attention on the example sentence "Our planet is a lonely speck": each token attends only to a window of \(c\) neighbouring tokens, which always includes the score for the word itself (\(q_ik_i^T\)). The computational complexity is \(O(c \times T \times d)\) with \(c < T\), i.e., linear in the sequence length: \(O(3Td)\) for \(c=3\) and \(O(5Td)\) for \(c=5\). A minimal sketch follows below.
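A minimal NumPy sketch of this pattern (an illustration, not any library's implementation; the sequence length, dimension, and window size below are made up):

```python
import numpy as np

def local_attention(Q, K, V, c=3):
    """Sliding-window attention: token i attends only to the c tokens
    centred on i (including itself). Cost is O(c*T*d) instead of O(T^2*d)."""
    T, d = Q.shape
    half = c // 2
    out = np.zeros_like(V)
    for i in range(T):
        lo, hi = max(0, i - half), min(T, i + half + 1)  # local window around i
        scores = Q[i] @ K[lo:hi].T / np.sqrt(d)           # at most c scores, not T
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[i] = weights @ V[lo:hi]
    return out

T, d = 6, 8          # "Our planet is a lonely speck" -> 6 tokens (illustrative)
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))
print(local_attention(Q, K, V, c=3).shape)   # (6, 8)
```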
(Figure) Global + local attention on the same sentence: with a window of \(c=3\) the local part costs \(O(3Td)\); making the [cls] token global, so that it attends to all \(T\) tokens, adds \(Td\) and gives \(O((3T+T)d)\); making the global attention symmetric, so that every token also attends to [cls], adds another \(Td\) and gives \(O((3T+T+T)d)\). All of these remain linear in \(T\).
For comparison, full causal attention has computational complexity \(O(\frac{T^2}{2})\), since the causal mask leaves roughly half of the \(T \times T\) scores to compute.
(Figure) Merging attention patterns across MHA heads: the masks shown have 50% non-zero values (50% sparse) and 6.25% non-zero values (93% sparse), respectively.
\(hash(q_1)= \textit{argmax}(concat(0.5,0.25,-0.5,-0.25))=0\)
\(hash(k_2)= \textit{argmax}(concat(-0.25,-0.5,0.25,0.5))=3\)
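A small sketch of this style of random-projection (angular LSH) bucketing; the vectors and the projection matrix below are chosen purely so that the projections reproduce the numbers above:

```python
import numpy as np

def lsh_hash(x, R):
    """Angular LSH bucket: project x with R, concatenate the projection with
    its negation, and take the argmax index as the bucket id."""
    xR = x @ R
    return int(np.argmax(np.concatenate([xR, -xR])))

# Illustrative inputs: in practice R is a random Gaussian matrix; here the
# identity is used so that q1 @ R and k2 @ R match the numbers in the example.
R = np.eye(2)
q1 = np.array([0.5, 0.25])
k2 = np.array([-0.25, -0.5])
print(lsh_hash(q1, R))   # argmax(0.5, 0.25, -0.5, -0.25) -> bucket 0
print(lsh_hash(k2, R))   # argmax(-0.25, -0.5, 0.25, 0.5) -> bucket 3
```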
Query | hash by R1 | hash by R2 |
---|---|---|
\(q_1\) | 0 | 2 |
\(q_2\) | 3 | 2 |
\(q_3\) | 0 | 2 |
\(q_4\) | 0 | 2 |

After sorting the queries by hash bucket:

Query | hash by R1 | hash by R2 |
---|---|---|
\(q_1\) | 0 | 2 |
\(q_3\) | 0 | 2 |
\(q_4\) | 0 | 2 |
\(q_2\) | 3 | 2 |
\(L=T\) in the figure below
(Figure) Naive implementation vs. a fused kernel. Naive: the compute units (registers) read \(a\), \(b\) and the mul instruction from main memory (DRAM), compute \(c = a \cdot b\), and write \(c\) back; they then read \(c\) and the sqrt instruction, compute \(\sqrt{c}\), and write \(d\). While data is moving, the compute is idle. Fused: read \(a\), \(b\) and both instructions (mul, sqrt) once, keep the intermediate in registers, and write only the final result.
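A toy CPU-side illustration of the same point (purely conceptual; real fusion happens inside a single GPU kernel that keeps the intermediate in registers):

```python
import numpy as np

a = np.random.rand(1_000_000)
b = np.random.rand(1_000_000)

# Unfused: two kernels, and the intermediate c makes a round trip through memory.
def unfused(a, b):
    c = a * b            # kernel 1: read a, b -> write c
    d = np.sqrt(c)       # kernel 2: read c    -> write d
    return d

# "Fused": one expression; a real fused GPU kernel would read a, b once,
# keep a*b in registers, and write only the final result.
def fused(a, b):
    return np.sqrt(a * b)

assert np.allclose(unfused(a, b), fused(a, b))
```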
\(P=\textit{Concat}(\beta_0 f_{00},\beta_1 f_{01},\cdots, \beta_8 f_{22})\)
where,
Worked example: split the vector \(x=[x^A,x^B]\), where
\(x^A = (-1.02, -1.09, 0.24, 0.87, -1.32)\) and \(x^B = (-0.6, -5.04, -2.0, 1.03, -0.14)\).

Split: compute each half with its own local maximum,
\(m_1=\max(x^A)=0.87\), \(\exp(x^A-m_1) = (0.15, 0.14, 0.53, 1.0, 0.11)\)
\(m_2=\max(x^B)=1.03\), \(\exp(x^B-m_2) = (0.2, 0.0, 0.05, 1.0, 0.31)\)

Combine: rescale with the global maximum \(m=\max(m_1,m_2)=1.03\) (multiply the first half by \(\exp(m_1-m)\)), giving the unnormalized values
\((0.13, 0.12, 0.45, 0.85, 0.1, 0.2, 0.0, 0.05, 1.0, 0.31)\)
and, after dividing by their sum, the softmax
\((0.04, 0.04, 0.14, 0.27, 0.03, 0.06, 0.0, 0.01, 0.31, 0.1)\)
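The same worked example in a few lines of NumPy; this split-then-rescale trick is the blockwise ("online") softmax that the IO-aware attention below builds on:

```python
import numpy as np

xA = np.array([-1.02, -1.09, 0.24, 0.87, -1.32])
xB = np.array([-0.60, -5.04, -2.00, 1.03, -0.14])

# Split: each half uses its own local max for numerical stability.
m1, m2 = xA.max(), xB.max()
eA, eB = np.exp(xA - m1), np.exp(xB - m2)   # ~[0.15 0.14 0.53 1. 0.11], ~[0.2 0. 0.05 1. 0.31]

# Combine: rescale each half to the global max, then normalize once.
m = max(m1, m2)
e = np.concatenate([eA * np.exp(m1 - m), eB * np.exp(m2 - m)])
softmax_split = e / e.sum()

# Matches the direct softmax over the full vector.
x = np.concatenate([xA, xB])
softmax_full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.round(softmax_split, 2))           # close to the normalized row shown above
assert np.allclose(softmax_split, softmax_full)
```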
For \(d=64\), 4 bytes per element, and \(M = 100\) KB of SRAM, \(B_c=\frac{100 \times 1024 \ B}{4 \times 64 \times 4 \ B} = 100\) and \(B_r = \min(B_c, d) = 64\).
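A quick check of that block-size rule, assuming (as above) that the SRAM budget \(M\) is given in bytes and following the \(B_c=\lceil M/4d \rceil\), \(B_r=\min(B_c,d)\) choice from the FlashAttention paper:

```python
import math

def flash_block_sizes(M_bytes, d, bytes_per_elem=4):
    """B_c = ceil(M / (4*d)) and B_r = min(B_c, d), with M counted in elements."""
    M_elems = M_bytes // bytes_per_elem
    B_c = math.ceil(M_elems / (4 * d))
    B_r = min(B_c, d)
    return B_c, B_r

print(flash_block_sizes(100 * 1024, d=64))   # (100, 64)
```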
(Figure) Walkthrough of the tiling scheme. Suppose \(N=128, d=16, B_c=16, B_r=32\). The full \(Q\), \(K\), \(V\) and the output \(O\) (initialized to a zero vector), along with the running statistics, live in HBM; blocks of \(K, V\) (\(B_c\) rows each) and blocks of \(Q, O\) (\(B_r\) rows each) are loaded into SRAM, giving \(T_c = N/B_c = 8\) column blocks and \(T_r = N/B_r = 4\) row blocks. A sketch of the resulting loop structure follows below.
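A NumPy sketch of that loop structure for these numbers (a simplified FlashAttention-1 style forward pass, without masking or dropout; the full arrays stand in for HBM and the slices taken inside the loops stand in for SRAM tiles):

```python
import numpy as np

def flash_attention_forward(Q, K, V, B_r=32, B_c=16):
    """Blockwise softmax(Q K^T / sqrt(d)) V without materializing the N x N score matrix."""
    N, d = Q.shape
    O = np.zeros((N, d))              # output, initialized to zero vectors ("HBM")
    l = np.zeros(N)                   # running softmax denominators
    m = np.full(N, -np.inf)           # running row maxima

    for j in range(0, N, B_c):        # outer loop over K/V blocks (T_c = N/B_c blocks)
        Kj, Vj = K[j:j+B_c], V[j:j+B_c]          # load a K/V tile into "SRAM"
        for i in range(0, N, B_r):    # inner loop over Q/O blocks (T_r = N/B_r blocks)
            Qi = Q[i:i+B_r]                       # load a Q tile into "SRAM"
            S = Qi @ Kj.T / np.sqrt(d)            # B_r x B_c block of scores
            m_block = S.max(axis=1)
            P = np.exp(S - m_block[:, None])
            l_block = P.sum(axis=1)

            # online-softmax update of the running max, denominator, and output
            m_new = np.maximum(m[i:i+B_r], m_block)
            scale_old = np.exp(m[i:i+B_r] - m_new)
            scale_new = np.exp(m_block - m_new)
            l_new = scale_old * l[i:i+B_r] + scale_new * l_block
            O[i:i+B_r] = ((l[i:i+B_r] * scale_old)[:, None] * O[i:i+B_r]
                          + scale_new[:, None] * (P @ Vj)) / l_new[:, None]
            m[i:i+B_r], l[i:i+B_r] = m_new, l_new
    return O

N, d = 128, 16
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
ref = np.exp(Q @ K.T / np.sqrt(d))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_forward(Q, K, V), ref)
```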
KV cache per token (in bytes) =
2 \(\times\) num_layers \(\times\) (num_heads \(\times\) dim_head) \(\times\) precision_in_bytes

Model: GPT-3 (96 layers, 96 heads, head dimension 128, 4 bytes per value)
KV cache per token = 2 \(\times\) 96 \(\times\) (96 \(\times\) 128) \(\times\) 4
= 9.4 MB
For a 2K context length: 2048 \(\times\) 9.4 MB
= 19.3 GB
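The same calculation as a small helper; the GPT-3 numbers (96 layers, 96 heads of dimension 128, 4 bytes per value) are the ones used above:

```python
def kv_cache_bytes(num_layers, num_heads, dim_head, precision_bytes=4, context_len=1):
    """2 (K and V) * layers * hidden size * bytes per value, per token, times context length."""
    per_token = 2 * num_layers * (num_heads * dim_head) * precision_bytes
    return per_token * context_len

per_token = kv_cache_bytes(num_layers=96, num_heads=96, dim_head=128)
print(per_token / 1e6)                                # ~9.4 MB per token
print(kv_cache_bytes(96, 96, 128, 4, 2048) / 1e9)     # ~19.3 GB for a 2K context
```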
Mistral 7B uses Sliding window attention
Can we sparsify Causal mask attention?
What happens during inference?
What happens for masked attention? Remember that we add a mask matrix to \(QK^T\), so it should not matter (my guess)
Local attention and blockwise attention trade off performance
Can we replace full attention with sparse attention during inference, in a transformer trained with full attention?
New attention layer: Lambda Layer
Various Ideas
Methods that embrace the limitation and build a mechanism around it
Pre-trained Model with parametric memory (implicit knowledge)
World knowledge in non-parametric memory (explicit knowledge)
Pre-trained Neural Retriever
2. Sliding Window
The second line of work questions whether full attention is essential and tries to come up with approaches that do not require full attention, thereby reducing the memory and computation requirements.
Unidirectional:
Bidirectional:
These approximations do not come with theoretical guarantees.
Random Attention: (From graph theory perspective)
Connected
Local attention
Random attention
The weight of an edge denotes the similarity score between nodes
its second eigenvalue (of the adjacency matrix) is quite far from the first eigenvalue
Global attention:
make some existing tokens “global”, which attend over the entire sequence
Two variants: ITC (internal transformer construction) and ETC (extended transformer construction)
What aspects of the self-attention model are necessary for its performance?
Can we achieve the empirical benefits of a fully quadratic self-attention scheme using fewer inner-products?
Longformer: text8, enwiki8, pretrain-finetune setting
Context Length of Today's LLMs
GPT-4, Gemini 1.0 : 32K
Mosaic ML MPT: 65K
Anthropic Claude: 100K
LLAMA-2: 4K, Attention: GQA
Mistral-7B: 8K, Attention: Local attention (SWA exploits the stacked layers of a transformer to attend to information beyond the window size W)
Gemini 1.5-pro (standard), GPT-4 Turbo: 128K
Gemini 1.5-pro: 1000K
Claude 2.1: 200K
Attention | text8 | enwiki8 | TriviaQA | OpenKP | HotpotQA | WikiHop | WikiText-103 |
---|---|---|---|---|---|---|---|
Longformer | yes | yes | yes | yes | yes | | |
Compressive Transformer | yes | yes | | | | | |
text8, enwiki8, WikiText-103: language modelling
LongFormer
a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention.
scales linearly with the sequence length
can act as a drop-in replacement for the self-attention mechanism in pretrained Transformers (take an existing checkpoint and use Longformer attention during fine-tuning)
We show that BIGBIRD is a universal approximator of sequence functions and is Turing complete, thereby preserving these properties of the quadratic, full attention model.
Visual inspection showed that most layers had sparse attention patterns across most data points, suggesting that some form of sparsity could be introduced without significantly affecting performance. [Ref]
Motivation for sparse attention
Brief Intro to GPU architecture
Compute: Arithmetic and other instructions are executed by the SMs (streaming multiprocessors)
Storage: data and code are accessed from DRAM (a.k.a. High Bandwidth Memory (HBM)) via the L2 cache
Model | Number of SMs | L2 cache | HBM (DRAM) | Bandwidth |
---|---|---|---|---|
A100 | 108 | 40 MB | 80 GB | up to 2 TB/s |
Streaming Multiprocessors
Each SM has its own instruction schedulers and various instruction execution pipelines
Core operation: Multiply and add
There are 512 TF32 cores in a single SM of A100.
Each core executes one multiply-and-add operation per clock
Therefore, FLOPs per clock is simply two times the number of cores. That is, 1024 FLOPs per clock
With 1.4 GHz clock frequency, this translates to 1.43 TFLOPS (Tera FLOPs per Second)
With 108 SMs, the overall peak throughput is 154 TF32 TFLOPS
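The same back-of-the-envelope arithmetic in code, using the numbers assumed above (512 TF32 cores per SM, 2 FLOPs per multiply-add, 1.4 GHz, 108 SMs):

```python
cores_per_sm   = 512        # TF32 cores in one A100 SM (as assumed above)
flops_per_core = 2          # one multiply-add = 2 FLOPs per clock
clock_hz       = 1.4e9      # 1.4 GHz
num_sm         = 108

per_sm_tflops = cores_per_sm * flops_per_core * clock_hz / 1e12
print(per_sm_tflops)                 # ~1.43 TFLOPS per SM
print(per_sm_tflops * num_sm)        # ~154.8 TF32 TFLOPS for the whole GPU
```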
Streaming Multiprocessors
At runtime, a thread block is placed on an SM for execution, enabling all threads in a thread block to communicate and synchronize efficiently
Launching a function with a single thread block would only give work to a single SM, therefore to fully utilize a GPU with multiple SMs one needs to launch many thread blocks
Since an SM can execute multiple thread blocks concurrently, typically one wants the number of thread blocks to be several times higher than the number of SMs
(This is also why people often drop the last, smaller batch when the number of samples is less than the batch size: it launches too little work to keep all SMs busy.)
How do Tensor Cores differ from CUDA cores?
Tensor cores accelerate matrix multiply and accumulate operations by operating on small matrix blocks (say, 4x4 blocks)
When math operations cannot be formulated in terms of matrix blocks they are executed in other CUDA cores
Performance Measure
Task: read inputs - do math - write outputs
\(T_{mem}\): Time to read inputs from memory
\(T_{math}\): Time to perform the operation
Total time = \(\max(T_{mem},T_{math})\)
If math time is longer we say that a function is math limited (bounded), if memory time is longer then it is memory limited (bounded).
Suppose the algorithm is memory-bound, i.e., \(T_{mem} > T_{math}\). Since \(T_{mem} = \frac{\#bytes}{BW_{mem}}\) and \(T_{math} = \frac{\#ops}{BW_{math}}\), we can write

\(\frac{\#ops}{\#bytes} < \frac{BW_{math}}{BW_{mem}}\)

The left-hand side is the arithmetic intensity of the algorithm; the right-hand side is the ops-per-byte ratio of the processor.
Multi-head attention is Memory bound
The arithmetic intensity is less than the OPs per Byte ratio
That is, the algorithm spends more time reading from and writing to memory than on the math operations (compute)!
How much time is spent in memory or math operations depends on both the algorithm and its implementation, as well as the processor’s bandwidths
Can we re-write the attention algorithm so that it spends less time reading from and writing to memory? That is, can we translate all element-wise operations in MHA into mat-mul ops, use kernel fusion, and make it IO-aware?
The algorithm for (standard) self-attention, with its memory traffic made explicit, is sketched below.
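A plain NumPy sketch (one head, no masking; not any particular library's kernel) with comments marking where the \(N \times N\) intermediates would be written to and re-read from HBM on a GPU:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard self-attention for one head. S and P are N x N and, on a GPU,
    would be written to and re-read from HBM between the three steps."""
    N, d = Q.shape
    S = Q @ K.T / np.sqrt(d)                      # write N*N scores
    P = np.exp(S - S.max(axis=1, keepdims=True))  # read S, write N*N probabilities
    P = P / P.sum(axis=1, keepdims=True)
    O = P @ V                                     # read P, write N*d output
    return O

N, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(naive_attention(Q, K, V).shape)    # (1024, 64)
```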
Say we have two matrices of size 100 by 100 and want to multiply them.
First we need to access all 20,000 input elements, which amounts to 80 KB (4 bytes per element).
Compute: \(2 \times 100^3 = 2\) million FLOPs.
Arithmetic intensity = \(\frac{2 \times 10^6}{80 \times 10^3}=25\)
Ops-per-byte = \(\frac{19 \times 10^{12} }{1 \times 10^{12}}=19\) (e.g., 19 TFLOPS of peak compute and 1 TB/s of memory bandwidth)
Since \(25 > 19\), the matrix multiplication is math-bound.
Say we have a matrix of size 100 by 100 and want to scale it by 2.
First we need to access all 10,000 elements, which amounts to 40 KB (4 bytes per element).
Compute: 10,000 FLOPs (one multiply per element).
Arithmetic intensity = \(\frac{10^4}{40 \times 10^3}=0.25\)
Ops-per-byte = \(\frac{19 \times 10^{12} }{1 \times 10^{12}}=19\)
Since \(0.25 < 19\), the scaling operation is memory-bound.
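Both examples in code, using the same assumed processor as above (19 TFLOPS of peak compute, 1 TB/s of memory bandwidth):

```python
def arithmetic_intensity(flops, bytes_moved):
    return flops / bytes_moved

ops_per_byte = 19e12 / 1e12          # peak FLOPs / memory bandwidth = 19

# 100x100 matmul: read 2 * 100*100 elements (4 B each), do 2*100^3 FLOPs.
matmul_ai = arithmetic_intensity(2 * 100**3, 2 * 100 * 100 * 4)
print(matmul_ai, matmul_ai > ops_per_byte)       # 25.0 True  -> math-bound

# Scaling a 100x100 matrix by 2: read 100*100 elements, do 100*100 FLOPs.
scale_ai = arithmetic_intensity(100 * 100, 100 * 100 * 4)
print(scale_ai, scale_ai > ops_per_byte)         # 0.25 False -> memory-bound
```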