CS6910: Fundamentals of Deep Learning

Lecture 16: Transformers: Multi-headed self-attention, cross-attention

Mitesh M. Khapra, Arun Prakash A

AI4Bharat, Department of Computer Science and Engineering, IIT Madras

Module 16.1: Limitations of sequential models


[Figure: An encoder RNN reads "I enjoyed the movie transformers", producing hidden states \(h_1,\cdots,h_5\) from the initial state \(h_0\); these feed a decoder RNN. \(h_i\): hidden states of the encoder RNN.]

The final state \(h_5\) is called a concept vector, thought vector or annotation

Sequence to Sequence 


[Figure: Sequence-to-sequence RNN. The encoder reads "I enjoyed the movie transformers" and produces hidden states \(h_1,\cdots,h_5\); the decoder is initialized with \(s_0=h_5\) and generates the Tamil translation "Naan transfarmar padaththai rasiththen" through states \(s_1,s_2,s_3,\cdots\). \(h_i\): hidden states of the encoder RNN; \(s_i\): hidden states of the decoder RNN.]

The final state \(h_5\) is called a concept vector, thought vector, context vector or annotation

However, this single vector does not capture the alignment of words between the source sentence and the target sentence.

The attention mechanism resolves this.

Attention: A quick tour

[Figure: An RNN encoder over "I enjoyed the film transformers", producing hidden states \(h_1,\cdots,h_5\) from \(h_0\).]

Suppose that we create a copy of these hidden state vectors \((h_1,\cdots,h_5)\) and make them available to the decoder.

Once all these vectors are available to the decoder, we can set the encoder aside (until decoding completes).

The attention mechanism consumes these vectors as one of its inputs.

[Figure: At decoding step 1, attention weights \(\alpha_{11},\cdots,\alpha_{15}\) over the copied hidden states \(h_1,\cdots,h_5\) are summed into the context vector \(c_1\), which is fed to the decoder along with \(s_0\).]
\begin{aligned} \mathbf{c}_t &= \sum_{i=1}^n \alpha_{ti} \boldsymbol{h}_i & \small{\text{; Context vector for output }y_t} \end{aligned}

where \(n\) is the number of source words and \(t\) is the decoder time step.

Equivalently, in matrix form:

c_1 = \begin{bmatrix} |&|&|&|&|\\ h_1&h_2&h_3&h_4&h_5\\ |&|&|&|&| \end{bmatrix} \begin{bmatrix} \alpha_{11}\\ \alpha_{12}\\ \alpha_{13}\\ \alpha_{14}\\ \alpha_{15} \end{bmatrix}
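For concreteness, here is a minimal NumPy sketch of this weighted sum (the hidden-state values and attention weights are toy numbers assumed for illustration, not from the lecture):

```python
import numpy as np

# five encoder hidden states; row i is h_{i+1} (toy 3-dimensional vectors)
H = np.array([[0.1, 0.4, 0.2],
              [0.3, 0.1, 0.5],
              [0.2, 0.2, 0.2],
              [0.5, 0.3, 0.1],
              [0.4, 0.5, 0.3]])

# attention weights alpha_{1i} for decoder time step t = 1 (they sum to 1)
alpha_1 = np.array([0.6, 0.1, 0.1, 0.1, 0.1])

# context vector c_1 = sum_i alpha_{1i} h_i
c_1 = alpha_1 @ H
print(c_1)
```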
[Figure: \(c_1\), \(s_0\) and the input token <Go> produce the decoder state \(s_1\), which emits the first output word "Naan".]

[Figure: At step 2, the weights \(\alpha_{21},\cdots,\alpha_{25}\) give \(c_2 = \sum_i \alpha_{2i} h_i\); together with \(s_1\) and the input "Naan", the decoder produces \(s_2\) and emits "transfarmar".]

[Figure: At step 3, \(\alpha_{31},\cdots,\alpha_{35}\) give \(c_3\); the decoder produces \(s_3\) and emits "padaththai".]

[Figure: At step 4, \(\alpha_{41},\cdots,\alpha_{45}\) give \(c_4\); the decoder produces \(s_4\) and emits "rasiththen".]


Alignment of words:

[Figure: Alignment between the target words "Naan, transfarmar, padaththai, rasiththen" and the source words "I, enjoyed, the, film, transformers".]

\alpha_{ti}=\text{align}(y_t,h_i) = \frac{\exp(\text{score}(\boldsymbol{s}_{t-1}, \boldsymbol{h}_i))}{\sum_{i'=1}^n \exp(\text{score}(\boldsymbol{s}_{t-1}, \boldsymbol{h}_{i'}))}
[Figure: The attention weights recover this alignment: \(\alpha_{11}\) links "Naan" to "I", \(\alpha_{42}\) links "rasiththen" to "enjoyed", \(\alpha_{34}\) links "padaththai" to "film", and \(\alpha_{25}\) links "transfarmar" to "transformers".]


QP: Can \(h_{i}\), for all \(i\), be computed in parallel?

No! Each \(h_i\) depends on \(h_{i-1}\), so the encoder RNN must be unrolled sequentially.

QP: Can \(\alpha_{ti}\) be computed in parallel for all \(i\) at time step \(t\)?

Yes. \(h_i\) is available for all \(i\), and \(s_{t-1}\) is also available at time step \(t\).

Takeaway: Attention can be parallelized.

Approaches to implement the score function

  • Content-based attention: \(\text{score}(s_{t-1},h_i) = \text{cosine}(s_{t-1},h_i)\)

  • Additive (concat) attention: \(\text{score}(s_{t-1},h_i) = v_a^T \tanh(W_a[s_{t-1};h_i])\)

  • Dot-product attention: \(\text{score}(s_{t-1},h_i) = s_{t-1}^Th_i\)

  • Scaled dot-product attention: \(\text{score}(s_{t-1},h_i) = \frac{s_{t-1}^Th_i}{\sqrt{n}}\)

All score functions take in two vectors and produce a scalar.
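A minimal NumPy sketch of these four score functions (the dimensions and random inputs are assumptions for illustration, not from the lecture):

```python
import numpy as np

def cosine_score(s, h):
    # content-based attention: cosine similarity between decoder and encoder states
    return float(s @ h / (np.linalg.norm(s) * np.linalg.norm(h)))

def additive_score(s, h, W_a, v_a):
    # additive (concat) attention: v_a^T tanh(W_a [s; h])
    return float(v_a @ np.tanh(W_a @ np.concatenate([s, h])))

def dot_score(s, h):
    # dot-product attention
    return float(s @ h)

def scaled_dot_score(s, h):
    # scaled dot-product attention: scale the dot product down by sqrt of the vector dimension
    return float(s @ h / np.sqrt(len(h)))

# toy example: one decoder state and one encoder state of dimension 4
rng = np.random.default_rng(0)
s, h = rng.normal(size=4), rng.normal(size=4)
W_a, v_a = rng.normal(size=(6, 8)), rng.normal(size=6)
print(cosine_score(s, h), additive_score(s, h, W_a, v_a),
      dot_score(s, h), scaled_dot_score(s, h))
```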

Major Limitation

Everything about the RNN-based sequence-to-sequence model seems good so far.

It performed well for translation once equipped with an attention mechanism.

However, there is a major limitation in training the model.

Given a training example, we cannot parallelize the sequence of computations across time steps: \(h_i\) must be computed before \(h_{i+1}\), and \(s_t\) before \(s_{t+1}\).


Wishlist: come up with a new architecture that incorporates the attention mechanism and also allows parallelization (and, of course, gets rid of the vanishing/exploding gradient problems).

Module 16.2: Attention is all you need


Transformers: Attention is all you need

[Figure: The RNN seq2seq model from before, repeated for reference.]

Transition to Transformers

[Figure: The recurrent blocks abstracted away: an Encoder produces \(h_1,\cdots,h_5\) and a Decoder produces the states \(s_1,s_2,\cdots\) that emit the output words.]

[Figure: The same abstraction with attention: the Decoder attends to the Encoder's outputs through weights \(\alpha_{11},\cdots,\alpha_{15}\).]

In the Transformer, the Encoder is built from a Self-Attention layer followed by a Feed Forward Network, and the Decoder is built from a Self-Attention layer, an Encoder-Decoder (cross) Attention layer, and a Feed Forward Network.

We will see each of these components (and a few more) one at a time in detail and connect them together to synthesize the final architecture

Word Embedding

[Figure: Inside the Encoder, the word embeddings of "I enjoyed the movie transformers" enter a Self-Attention layer, which produces one output vector per word (\(z_1,\cdots,z_5\)).]

Note that the inputs are vectors (word embeddings) and the outputs are also vectors

1 0 0
0 1 0
1 1 0
1 0 1
1 1 1

Well, we know what attention is. All it requires is a pair of vectors as input.

You can think of these word embeddings as the \(h_i\) in the RNN encoder (if you wish to compare).

Self-Attention 

Let's take another example sentence

"The animal didn't cross the street because it was too tired".

The word "it" in the sentence refers to "Animal" or "Street"?

We know "it" is referring to the word "Animal"

Let's modify the sentence:

"The animal didn't cross the street because it was congested".

Now the word "it" is referring to the word "Street"

Therefore, it is important to establish a strong connection between the word "it" and the word "street" or "animal", based on the context.

This calls for an attention mechanism within the sentence itself. That is why it is called self-attention (to distinguish it from cross-attention, which we will see later).

The goal

Given a word in a sentence, we want to compute a relational score between that word and every other word in the sentence, such that the score is higher if the two words are related contextually. For example (scores for the sentence "The animal didn't cross the street because it ..."):

          The    animal  didn't  cross  the    street  because  it
The       0.6    0.1     0.05    0.05   0.02   0.02    0.02     0.1
animal    0.02   0.5     0.06    0.15   0.02   0.05    0.01     0.12
didn't    0.01   0.35    0.45    0.1    0.01   0.02    0.01     0.03
cross     .
the       .
street    .
because   .
it        0.01   0.6     0.02    0.1    0.01   0.2     0.01     0.01

We can think of the row headers of this table as \(s_{i}\) and the column headers as \(h_{j}\) (just for convenience).

However, now both sets of vectors, \(s_i\) and \(h_j\) for all \(i,j\), are available right from the start (whereas in the seq2seq model, all the \(h_j\) were available but each \(s_i\) was obtained one step at a time).

Does this allow us to compute all the rows of the table in parallel (i.e., all at once)?

Choice of Attention function

Recall the score function \(e_{jt}\) used in the seq2seq model:

score(s_{t-1},h_j)=V_a^T \tanh(U_{att}s_{t-1}+W_{att}h_j)

  • Two linear transformations (\(U_{att}s_{t-1}\) and \(W_{att}h_j\))

  • One non-linearity (\(\tanh\))

  • Finally, a dot product (with \(V_a\))

There are three vectors \((s,h,V_a)\) involved in computing the score at each time step (of the decoder).

However, the input to the self-attention module is only the word embeddings \(h_j\) for all \(j\).

So, we need to derive three vectors from each word embedding. How do we do it?

Matrix transformations! How many matrices do we need? Three.

And what should the values of the elements of these matrices be?

Transformation Matrices

q_j = W_Q h_j
k_j = W_K h_j
v_j = W_V h_j

\(q_j\) is called query vector for the word embedding \(h_j\)

\(k_j\) is called key vector for the word embedding \(h_j\)

\(v_j\) is called value vector for the word embedding \(h_j\)

\(W_Q\), \(W_K\) and \(W_V\) are the respective linear transformation (parameter) matrices, learned during training.

[Figure: As an analogy, a matrix \(W \in \mathbb{R}^{2 \times 3}\) linearly transforms 3-dimensional embeddings (of, say, "Animal" and "It") in \(\mathbb{R}^3\) into 2-dimensional vectors in \(\mathbb{R}^2\).]
In the Transformer, \(h_j \in \mathbb{R}^{512 \times 1}\) and \(W_Q, W_K, W_V \in \mathbb{R}^{64 \times 512}\), so \(q, k, v \in \mathbb{R}^{64 \times 1}\).
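A minimal NumPy sketch of these three projections (random matrices stand in for the learned parameters):

```python
import numpy as np

d_model, d_k = 512, 64          # embedding size and projected size, as in the lecture
rng = np.random.default_rng(0)

# stand-in parameter matrices (in practice these are learned)
W_Q = rng.normal(size=(d_k, d_model))
W_K = rng.normal(size=(d_k, d_model))
W_V = rng.normal(size=(d_k, d_model))

h_j = rng.normal(size=(d_model, 1))   # word embedding of one token

q_j = W_Q @ h_j   # query vector, shape (64, 1)
k_j = W_K @ h_j   # key vector,   shape (64, 1)
v_j = W_V @ h_j   # value vector, shape (64, 1)
print(q_j.shape, k_j.shape, v_j.shape)
```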


Let's first focus on calculating the first output \(z_1\) of the self-attention layer.

[Figure: Each word embedding is multiplied by \(W_Q, W_K, W_V\) to give its query, key and value vectors \((q_1,k_1,v_1),\cdots,(q_5,k_5,v_5)\), which are then combined to produce \(z_1\).]
Recall that in the seq2seq score \(\text{score}(s_{t-1},h_j)\), \(s_{t-1}\) is fixed at a decoding step while \(h_j\) varies. Analogously, here the score is \(\text{score}(q_1, k_j)\): \(q_1\) is fixed and \(k_j\) varies.

Using the dot product as the score function:

e_{1}=[q_1 \cdot k_1, \quad q_1 \cdot k_2, \quad \cdots, \quad q_1 \cdot k_5]
\alpha_{1j} = softmax(e_{1j})
z_{1} = \sum \limits_{j=1}^5 \alpha_{1j}v_j
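A minimal NumPy sketch of this computation of \(z_1\) (toy dimension \(d_k=4\) and random vectors, assumed for illustration):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# query, key and value vectors for 5 tokens (toy dimension 4)
rng = np.random.default_rng(0)
q = rng.normal(size=(5, 4))   # q[0] plays the role of q_1
k = rng.normal(size=(5, 4))
v = rng.normal(size=(5, 4))

e_1 = np.array([q[0] @ k[j] for j in range(5)])   # scores q_1 . k_j
alpha_1 = softmax(e_1)                            # attention weights alpha_{1j}
z_1 = sum(alpha_1[j] * v[j] for j in range(5))    # z_1 = sum_j alpha_{1j} v_j
print(z_1)
```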

What about the \(z_2\)?

Similarly, for \(z_2\), the query \(q_2\) is fixed and \(k_j\) varies:

e_{2}=[q_2 \cdot k_1, \quad q_2 \cdot k_2, \quad \cdots, \quad q_2 \cdot k_5]
\alpha_{2j} = softmax(e_{2j})
z_{2} = \sum \limits_{j=1}^5 \alpha_{2j}v_j

Repeat the procedure for all other \(z\).


Wait, can we vectorize all these computations and compute the outputs (\(z_1,z_2,\cdots,z_T\)) in one go?

Q = W_Q \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} = \begin{bmatrix} q_1 & q_2 & \cdots & q_T \end{bmatrix}, \quad Q \in \mathbb{R}^{64 \times T}


K = W_K \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} = \begin{bmatrix} k_1 & k_2 & \cdots & k_T \end{bmatrix}, \quad K \in \mathbb{R}^{64 \times T}


V = W_V \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} = \begin{bmatrix} v_1 & v_2 & \cdots & v_T \end{bmatrix}, \quad V \in \mathbb{R}^{64 \times T}
\begin{aligned} Z &= [z_1,z_2,\cdots,z_T] \\ &=softmax\big(\frac{Q^TK}{\sqrt{d_k}}\big)V^T \end{aligned}

where \(d_k\) is the dimension of the key vectors.

Since the values of \(Q^TK\) are scaled by \(\frac{1}{\sqrt{d_k}}\), this is called scaled dot-product attention.

\dim(Q^TK): (T \times 64)\,(64 \times T) = T \times T
\dim(Z): (T \times T)\,(T \times 64) = T \times 64
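A minimal NumPy sketch of this vectorized computation, following the lecture's column-per-token convention (random inputs assumed for illustration):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q, K, V have shape (d_k, T); each column is one token's vector."""
    d_k = K.shape[0]
    scores = Q.T @ K / np.sqrt(d_k)        # (T, T) matrix of scaled dot products
    alpha = softmax(scores, axis=-1)       # each row sums to 1
    return alpha @ V.T                     # Z has shape (T, d_k): one output vector per token

# toy example with T = 5 tokens, d_model = 512, d_k = 64
rng = np.random.default_rng(0)
T, d_model, d_k = 5, 512, 64
H = rng.normal(size=(d_model, T))
W_Q, W_K, W_V = (rng.normal(size=(d_k, d_model)) for _ in range(3))
Z = scaled_dot_product_attention(W_Q @ H, W_K @ H, W_V @ H)
print(Z.shape)  # (5, 64)
```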

Vectorized Output 

[Figure: The scaled dot-product attention block: MatMul (\(Q^TK\)) → Scale (\(\frac{1}{\sqrt{d_k}}\)) → Softmax → MatMul (with \(V\)).]

[Figure: A single attention head: \(H=\{h_1,h_2,\cdots,h_T\}\) is projected by \(W_Q, W_K, W_V\) into \(Q, K, V\), which feed the scaled dot-product attention.]

Two-head Attention

[Figure: Two heads, each with its own \(W_Q, W_K, W_V\) projecting \(H=\{h_1,h_2,\cdots,h_T\}\) into its own \(Q, K, V\) and its own scaled dot-product attention.]

Motivation for Multi-Head attention

What is the significance of having more than one filter/kernel in a CNN layer?

To learn more abstract representations, capture more meaningful interactions between inputs

Similarly, we can have more than one self-attention head, each with its own parameter matrices \((W_Q^i,W_K^i,W_V^i)\), with the hope that each head learns subtle contextual information.

This motivates "Multi-head Attention", which is a simple extension of single-head attention.

Just as each kernel in a CNN independently learns its own feature, each head in a Transformer independently computes its attention. (Parallel computation!)

Motivation 

Single-Head

The word "it" is strongly connected to the word "was".

Link to the colab: Tensor2Tensor

Motivation 

Two-Head

The word "it" is strongly connected to the word "was" in the first head

The word "it" is strongly connected to the word "animal" in the second head.

So it is evident (empirically) that having more than one attention head helps in capturing different contextual information about the sentence.

So, the multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.

A little more detail

q_j = W_Q h_j

h_j \in \mathbb{R}^{512 \times 1}
W_Q \in \mathbb{R}^{64 \times 512}
q_j \in \mathbb{R}^{64 \times 1}

The dimension of \(q_j\) is much smaller than that of \(h_j\). In fact, the word embedding \(h_j \in \mathbb{R}^{512}\) is projected to low-dimensional representations \(q_j,k_j,v_j \in \mathbb{R}^{64}\) by the matrices \(W_Q\), \(W_K\) and \(W_V\).

[Figure: \(h_j\) is projected by \((W_Q^1,W_K^1,W_V^1)\) and \((W_Q^2,W_K^2,W_V^2)\) into two scaled dot-product attention heads; the head outputs are concatenated and passed through a linear layer.]

How do we extend this to multi-head attention?

Recall how \(Q,K,V\) are obtained from \(H\).

Multi-Head Attention

[Figure: Multi-head attention with eight heads: each head \(i\) has its own \((W_Q^i,W_K^i,W_V^i)\) and its own scaled dot-product attention over the \(h_j\); the eight head outputs are concatenated (giving a \(T \times 512\) matrix) and passed through a linear layer.]
\text{MultiHead}(Q,K,V) = \text{Concatenate}(head_1,\cdots,head_8)\,W_O
head_i=\text{Attention}(Q^i,K^i,V^i)

The input is projected into \(h=8\) different representation subspaces.

So, multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.
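A minimal NumPy sketch of multi-head attention under these conventions (eight heads of size 64; random matrices stand in for the learned parameters \(W_Q^i, W_K^i, W_V^i\) and \(W_O\)):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # scaled dot-product attention; Q, K, V are (d_k, T), output is (T, d_k)
    return softmax(Q.T @ K / np.sqrt(K.shape[0]), axis=-1) @ V.T

def multi_head_attention(H, Ws, W_O):
    # Ws is a list of (W_Q^i, W_K^i, W_V^i) tuples, one per head
    heads = [attention(W_Q @ H, W_K @ H, W_V @ H) for (W_Q, W_K, W_V) in Ws]
    concat = np.concatenate(heads, axis=-1)    # (T, 8 * 64) = (T, 512)
    return concat @ W_O                        # final linear layer

# toy example: T = 5 tokens, d_model = 512, 8 heads of size d_k = 64
rng = np.random.default_rng(0)
T, d_model, n_heads, d_k = 5, 512, 8, 64
H = rng.normal(size=(d_model, T))
Ws = [tuple(rng.normal(size=(d_k, d_model)) for _ in range(3)) for _ in range(n_heads)]
W_O = rng.normal(size=(n_heads * d_k, d_model))
print(multi_head_attention(H, Ws, W_O).shape)  # (5, 512)
```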

Back to Basic Block

[Figure: The basic Encoder block: the word embeddings of "I enjoyed the movie transformers" pass through a Self-Attention layer followed by a Feed Forward Network.]