
Chapter 3: Coding Attention Mechanisms


Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:tiktoken, "~> 0.3.2"},
  {:table_rex, "~> 3.1.1"}
])

Introduction

There are four different variants of attention mechanisms:

  • Simplified self-attention, a simplified technique to introduce the broader idea.

  • Self-attention, the basis of the mechanism used in LLMs, with trainable weights.

  • Causal attention, a type of self-attention used in LLMs that allows a model to consider only previous and current inputs in a sequence, ensuring temporal order during text generation.

  • Multi-head attention, an extension of self-attention and causal attention that enables the model to simultaneously attend to information from different representation subspaces.

3.1 The problem with modeling long sequences

Before the advent of transformers, recurrent neural networks (RNNs) were the most popular encoder-decoder architecture for language translation.

An RNN is a type of neural network where outputs from previous steps are fed as inputs to the current step, making them well-suited for sequential data like text.

In an encoder-decoder RNN, the input text is fed into the encoder, which processes it sequentially. The encoder updates its hidden state at each step, trying to capture the entire meaning of the input sentence in the final hidden state.

The decoder then takes this final hidden state to start generating the translated sentence, one word at a time. It also updates its hidden state at each step, which is supposed to carry the context necessary for the next-word prediction.

The big issue and limitation of encoder-decoder RNNs is that the RNN can’t directly access earlier hidden states from the encoder during the decoding phase. Consequently, it relies solely on the current hidden state, which must encapsulate all relevant information. This can lead to a loss of context, especially in complex sentences where dependencies might span long distances.

3.2 Capturing data dependencies with attention mechanisms

RNNs work fine for translating short sentences but don’t work well for longer texts as they don’t have direct access to previous words in the input.

The Bahdanau attention mechanism for RNNs modifies the encoder-decoder RNN so that the decoder can selectively access different parts of the input sequence at each decoding step. The importance of each part is determined by so-called attention weights.

Researchers later found that RNN architectures are not required for building deep neural networks for natural language processing and proposed the original transformer architecture, with a self-attention mechanism inspired by the Bahdanau attention mechanism.

Self-attention is a mechanism that allows each position in the input sequence to attend to all positions in the same sequence when computing the representation of a sequence.

Self-attention is a mechanism in transformers that is used to compute more efficient input representations by allowing each position in a sequence to interact with and weigh the importance of all other positions within the same sequence.

3.3 Attending to different parts of the input with self-attention

Self-attention serves as the cornerstone of every LLM based on the transformer architecture.

In self-attention, the “self” refers to the mechanism’s ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image. This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence.

3.3.1 A simple self-attention mechanism without trainable weights

The goal of self-attention is to compute a context vector, for each input element, that combines information from all other input elements.

In self-attention, our goal is to calculate context vectors z^(i) for each element x^(i) in the input sequence. A context vector can be interpreted as an enriched embedding vector.

In self-attention, context vectors play a crucial role. Their purpose is to create enriched representations of each element in an input sequence (like a sentence). This is essential in LLMs, which need to understand the relationship and relevance of words in a sentence to each other.

inputs =
  Nx.tensor(
    [
      # Your (x^1)
      [0.43, 0.15, 0.89],
      # journey (x^2)
      [0.55, 0.87, 0.66],
      # starts (x^3)
      [0.57, 0.85, 0.64],
      # with (x^4)
      [0.22, 0.58, 0.33],
      # one (x^5)
      [0.77, 0.25, 0.10],
      # step (x^6)
      [0.05, 0.80, 0.55]
    ]
  )

The first step of implementing self-attention is to compute the intermediate values ω, referred to as attention scores.

A dot product is essentially just a concise way of multiplying two vectors element-wise and then summing the products.

The dot product is a measure of similarity because it quantifies how much two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which elements in a sequence attend to each other: the higher the dot product, the higher the similarity and attention score between two elements.
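
As a quick sanity check (an addition to the text, reusing the inputs tensor above), the manual element-wise multiply-and-sum should give exactly the same result as Nx.dot:

# Dot product of x^(1) and x^(2): element-wise product followed by a sum,
# which should match Nx.dot exactly.
elementwise = inputs[0] |> Nx.multiply(inputs[1]) |> Nx.sum()
dot_product = Nx.dot(inputs[0], inputs[1])
{elementwise, dot_product}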

# Attention scores ω between the query x^(2) and every input token
Nx.dot(inputs, inputs[1])

{sequence_size, _dim} = inputs.shape

# Attention scores for every query, built row by row
attn_scores =
  for i <- 0..(sequence_size - 1), reduce: [] do
    acc ->
      score = Nx.dot(inputs, inputs[i])
      acc ++ [score]
  end

attn_scores_tensor = Nx.stack(attn_scores)

# for-loops are generally slow, and we can achieve the
# same results using matrix multiplication
Nx.dot(inputs, Nx.transpose(inputs))

The next step is normalization.

This normalization is a convention that is useful for interpretation and for maintaining training stability in an LLM.

# Normalize each row of attention scores so that it sums to 1
attn_sum = Nx.sum(attn_scores_tensor, axes: [1], keep_axes: true)
attn_score_norm_tensor = Nx.divide(attn_scores_tensor, attn_sum)
# Check: every row of the normalized scores now sums to 1
Nx.sum(attn_score_norm_tensor, axes: [1])
# In practice, it's more common and advisable to use the softmax
# function for normalization.
softmax_naive = fn x ->
  exp_x = Nx.exp(x)
  Nx.divide(exp_x, Nx.sum(exp_x, axes: [1], keep_axes: true))
end

# the softmax function ensures that the attention weights are always positive.
# This makes the output interpretable as probabilities or relative importance, 
# where higher weights indicate greater importance.
attn_score_softmax_tensor = softmax_naive.(attn_scores_tensor)
# In practice, use an optimized softmax implementation to avoid numerical
# instability problems, such as overflow and underflow, when dealing with
# large or small input values.
Axon.Activations.softmax(attn_scores_tensor)

The final step is calculating the context vector z^(2) by multiplying the embedded input tokens, x^(i), with the corresponding attention weights and then summing the resulting vectors.
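
A minimal sketch of that weighted sum (an addition to the text, reusing the softmax-normalized weights in attn_score_softmax_tensor); the next section computes all context vectors at once with a single matrix multiplication:

# Attention weights for query x^(2) are row 1 of the softmax-normalized scores.
attn_weights_2 = attn_score_softmax_tensor[1]

# z^(2): weight every input row x^(i) by its attention weight and sum the rows.
context_vec_2 =
  attn_weights_2
  |> Nx.reshape({6, 1})
  |> Nx.multiply(inputs)
  |> Nx.sum(axes: [0])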

3.3.2 Computing attention weights for all input tokens

context_vector = Nx.dot(attn_score_softmax_tensor, inputs)

Self-Attention Steps (recapped in the snippet after this list):

  • Compute the attention scores as dot products between the inputs.
  • The attention weights are a normalized version of the attention scores.
  • The context vectors are computed as a weighted sum over the inputs.
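
The snippet below condenses those three steps into a single pass over the whole input matrix (a condensed recap added here, reusing the inputs tensor and the library softmax):

# 1. Attention scores: dot products between all pairs of inputs.
scores_all = Nx.dot(inputs, Nx.transpose(inputs))

# 2. Attention weights: row-wise softmax normalization of the scores.
weights_all = Axon.Activations.softmax(scores_all, axis: -1)

# 3. Context vectors: weighted sum over the inputs.
Nx.dot(weights_all, inputs)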

3.4 Implementing self-attention with trainable weights

The self-attention mechanism with trainable weights is also called scaled dot-product attention.

These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce “good” context vectors.

3.4.1 Computing the attention weights step by step

We are going to use three weight matrices W_q, W_k, and W_v. These three matrices are used to project the embedded input tokens, x^(i), into query, key, and value vectors.

# The second input token x^(2) and fixed example weight matrices
x_2 = inputs[1]
w_q = Nx.tensor([[0.2961, 0.5166], [0.2517, 0.6886], [0.0740, 0.8665]])
w_k = Nx.tensor([[0.1366, 0.1025], [0.1841, 0.7264], [0.3153, 0.6871]])
w_v = Nx.tensor([[0.0756, 0.1966], [0.3164, 0.4017], [0.1186, 0.8274]])

{w_q, w_k, w_v, x_2}
query_2 = Nx.dot(x_2, w_q)
key_2 = Nx.dot(x_2, w_k)
value_2 = Nx.dot(x_2, w_v)

In summary, weight parameters are the fundamental, learned coefficients that define the network’s connections, while attention weights are dynamic, context-specific values.

queries = Nx.dot(inputs, w_q)
keys = Nx.dot(inputs, w_k)
values = Nx.dot(inputs, w_v)
# Attention score ω_22 between query 2 and key 2
attn_score_22 = Nx.dot(keys[1], queries[1])
# All attention scores for a given query
attn_score_2 =
  keys
  |> Nx.transpose()
  |> then(&Nx.dot(queries[1], &1))

We now scale the attention scores by dividing them by the square root of the embedding dimension of the keys.

d_k = Nx.axis_size(keys, -1)
attn_weights_2 =
  attn_score_2
  |> Nx.divide(Nx.pow(d_k, 0.5))
  |> Axon.Activations.softmax(axis: -1)
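
As an aside (an illustrative addition, not from the original text), the effect of this scaling can be seen on a toy score vector: the larger the score magnitudes, the closer the softmax gets to a one-hot distribution, where gradients become very small.

# The same relative scores at two different magnitudes. The unscaled (large)
# version is far "peakier", which is what leads to small gradients in training.
scores = Nx.tensor([1.0, 2.0, 3.0])
small = Axon.Activations.softmax(scores, axis: -1)
large = Axon.Activations.softmax(Nx.multiply(scores, 8), axis: -1)
{small, large}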

The reason for the normalization by the embedding dimension size is to improve the training performance by avoiding small gradients. The final step is to compute the context vectors.

context_vec_2 = Nx.dot(attn_weights_2, values)

Why query, key, and value?

The terms “key,” “query,” and “value” in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

A “query” is analogous to a search query in a database. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The “key” is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence has an associated key. These keys are used to match with the query.

The “value” in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.

3.4.2 Implementing a compact self-attention Elixir module

defmodule SelfAttentionV1 do
  import Nx.Defn
  
  defn forward(x, w_query, w_key, w_value) do
    keys = Nx.dot(x, w_key)
    queries = Nx.dot(x, w_query)
    values = Nx.dot(x, w_value)

    d_k = Nx.axis_size(keys, -1)

    attn_score =
      keys
      |> Nx.transpose()
      |> then(&Nx.dot(queries, &1))

    attn_weights =
      attn_score
      |> Nx.divide(Nx.pow(d_k, 0.5))
      |> Axon.Activations.softmax(axis: -1)

    Nx.dot(attn_weights, values)
  end
end
context_vector = SelfAttentionV1.forward(inputs, w_q, w_k, w_v)
context_vector[1]

In self-attention, we transform the input vectors in the input matrix X with the three weight matrices, Wq, Wk, and Wv. Then, we compute the attention weight matrix based on the resulting queries (Q) and keys (K). Using the attention weights and values (V), we then compute the context vectors (Z).

defmodule SelfAttentionV2 do
  import Nx.Defn
  
  defn forward(x, w_query, w_key, w_value) do
    keys = Axon.Layers.dense(x, w_key)
    queries = Axon.Layers.dense(x, w_query)
    values = Axon.Layers.dense(x, w_value)

    d_k = Nx.axis_size(keys, -1)

    attn_score =
      keys
      |> Nx.transpose()
      |> then(&Nx.dot(queries, &1))

    attn_weights =
      attn_score
      |> Nx.divide(Nx.pow(d_k, 0.5))
      |> Axon.Activations.softmax(axis: -1)

    Nx.dot(attn_weights, values)
  end
end
context_vector = SelfAttentionV2.forward(inputs, w_q, w_k, w_v)

The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence. This is essential in language modeling, where each word prediction should depend only on previous words.

The multi-head component involves splitting the attention mechanism into multiple “heads.” Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions.

3.5 Hiding future words with causal attention

Causal attention, also known as masked attention, restricts a model to consider only previous and current inputs in a sequence when processing any given token.

3.5.1 Applying a causal attention mask

One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zero out the elements above the diagonal, and renormalize the resulting matrix.

keys = Nx.dot(inputs, w_k)
queries = Nx.dot(inputs, w_q)

d_k = Nx.axis_size(keys, -1)

attn_score =
  keys
  |> Nx.transpose()
  |> then(&Nx.dot(queries, &1))

attn_weights =
  attn_score
  |> Nx.divide(Nx.pow(d_k, 0.5))
  |> Axon.Activations.softmax(axis: -1)
# Simple lower-triangular mask of ones
mask_simple =
  attn_weights
  |> then(&Nx.broadcast(1, &1))
  |> Nx.tril()

# Zero out weights above the diagonal, then renormalize each row
masked = Nx.multiply(attn_weights, mask_simple)
row_sum = Nx.sum(masked, axes: [1]) |> Nx.reshape({6, 1}) |> IO.inspect()
masked_norm = Nx.divide(masked, row_sum)

INFORMATION LEAKAGE

When we renormalize the attention weights after masking, what we’re essentially doing is recalculating the softmax over a smaller subset (since masked positions don’t contribute to the softmax value).

The mathematical elegance of softmax is that despite initially including all positions in the denominator, after masking and renormalizing, the effect of the masked positions is nullified.

There is no information leakage from future tokens.

We can take advantage of a mathematical property of the softmax function and implement the computation of the masked attention weights more efficiently, in fewer steps.

The softmax function converts its inputs into a probability distribution. When negative infinity values (-∞) are present in a row, the softmax function treats them as zero probability. (Mathematically, this is because e^(-∞) approaches 0.)
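
A one-row illustration of this property (an addition to the text, using plain Nx values):

# e^(-∞) is 0, so positions masked with negative infinity receive exactly zero
# attention weight, and the remaining weights renormalize among themselves.
row = Nx.tensor([0.2, 0.4, :neg_infinity, :neg_infinity])
Axon.Activations.softmax(row, axis: -1)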

# Mask of +infinity strictly above the diagonal
simple_mask =
  attn_weights
  |> then(&Nx.broadcast(Nx.Constants.infinity(), &1))
  |> Nx.triu(k: 1)

# Subtract the mask so future positions become -infinity before the softmax
masked = Nx.multiply(simple_mask, -1) |> Nx.add(attn_score)
attn_weights =
  masked
  |> Nx.divide(Nx.pow(d_k, 0.5))
  |> Axon.Activations.softmax(axis: -1)
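
As a quick check (an addition to the text), these weights should match the masked-and-renormalized weights computed earlier as masked_norm, up to floating-point error:

# Both masking strategies produce the same attention weights:
# zeroing + renormalizing vs. setting future scores to -infinity before softmax.
Nx.all_close(attn_weights, masked_norm)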

3.5.2 Masking additional attention weights with dropout

Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively “dropping” them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units.

In the attention mechanism, dropout is typically applied in two specific areas: after calculating the attention scores or after applying the attention weights to the value vectors.

key = Nx.Random.key(103)
example = Nx.iota({6,6}, type: :f32) |> Nx.fill(1)

output = Axon.Layers.dropout(example, key, rate: 0.5, mode: :train)

The values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2.

This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
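
A quick sanity check (an addition to the text; it assumes the output value above is the stateful dropout result whose tensor is available under output.output, the same access pattern used in the next cell):

# Roughly half the entries are zeroed and the survivors are scaled to 2.0, so the
# mean of the dropped-out matrix of ones stays close to the original mean of 1.0.
{Nx.mean(example), Nx.mean(output.output)}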

dropout = fn input, key ->
  # Returns the dropped-out tensor and the updated PRNG key
  result = Axon.Layers.dropout(input, key, rate: 0.5, mode: :train)
  {result.output, result.state["key"]}
end
{output, new_key} = dropout.(attn_weights, key)

3.5.3 Implementing a compact causal attention class

batch = Nx.stack([inputs, inputs])
batch.shape
defmodule Attention do
  import Nx.Defn

  def causal(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name, :d_in, :d_out])
    w_query = Axon.param("w_query", fn _ -> {opts[:d_in], opts[:d_out]} end)
    w_key = Axon.param("w_key", fn _ -> {opts[:d_in], opts[:d_out]} end)
    w_value = Axon.param("w_value", fn _ -> {opts[:d_in], opts[:d_out]} end)

    Axon.layer(
      &attention_impl/5,
      [input, w_query, w_key, w_value],
      name: opts[:name],
      op_name: :causal_attention
    )
  end

  defnp attention_impl(input, w_query, w_key, w_value, _opts \\ []) do
    # {b, num_tokens, d_in} = Nx.shape(input)
    keys = Nx.dot(input, w_key)
    queries = Nx.dot(input, w_query)
    values = Nx.dot(input, w_value)

    d_k = Nx.axis_size(keys, -1)

    attn_score =
      keys
      |> Nx.transpose(axes: [0, 2, 1])
      |> then(&Nx.dot(queries, [2], [0], &1, [1], [0]))

    simple_mask =
      attn_score
      |> then(&Nx.broadcast(Nx.Constants.infinity(), &1))
      |> Nx.triu(k: 1)

    masked = Nx.multiply(simple_mask, -1) |> Nx.add(attn_score)

    attn_weights =
      masked
      |> Nx.divide(Nx.pow(d_k, 0.5))
      |> Axon.Activations.softmax(axis: -1)

    Nx.dot(attn_weights, [2], [0], values, [1], [0])
  end
end
model =
  Axon.input("sequence", shape: {nil, 6, 3})
  |> Attention.causal(d_out: 2, d_in: 3, name: "causal")
  |> Axon.dropout(rate: 0.5)
{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1,6,3}, :f32)
params = init_fn.(template, %{})
predict_fn.(params, batch)
# to visualize the dropout, build the model with mode: :train
{init_fn, predict_fn} = Axon.build(model, mode: :train)
template = Nx.template({1,6,3}, :f32)
params = init_fn.(template, %{})
predict_fn.(params, batch)

3.6 Extending single-head attention to multi-head attention

The term “multi-head” refers to dividing the attention mechanism into multiple “heads,” each operating independently. In this context, a single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially.

3.6.1 Stacking multiple single-head attention layers

In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then combining their outputs.

defmodule MultiHeadAttentionWrapper do
  def multi_head_model(heads_numbers, heads_args \\ []) do
    for head_index <- 1..heads_numbers do
      head_args = Enum.at(heads_args, head_index - 1, [])
      single_head_model(head_args)
    end
  end

  def build_heads(models, heads_args \\ []) do
    for {model, head_index} <- Enum.with_index(models) do
      head_args = Enum.at(heads_args, head_index, [])
      build_single_head_model(model, head_args)
    end
  end

  def predict(models, input, opts \\ []) do
    axis = Keyword.get(opts, :axis, 2)
    predict_tasks = 
      for {params, _init_fn, predict_fn} <- models do
        Task.async(fn -> predict_fn.(params, input) end)
      end

    predict_tasks
    |> Enum.map(fn predict_task -> Task.await(predict_task) end)
    |> Nx.concatenate(axis: axis)
  end

  defp single_head_model(args) do
    shape = Keyword.get(args, :shape, {nil, 6, 3})
    d_out = Keyword.get(args, :d_out, 2)
    d_in = Keyword.get(args, :d_in, 3)
    dropout_rate = Keyword.get(args, :dropout_rate, 0.5)
    
    Axon.input("sequence", shape: shape)
    |> Attention.causal(d_out: d_out, d_in: d_in, name: "causal")
    |> Axon.dropout(rate: dropout_rate)
  end

  defp build_single_head_model(model, args) do
    shape = Keyword.get(args, :shape, {1,6,3})
    model_state = Keyword.get(args, :model_state, %{})
    {init_fn, predict_fn} = Axon.build(model)
    template = Nx.template(shape, :f32)
    params = init_fn.(template, model_state)
    {params, init_fn, predict_fn}
  end
end
heads = MultiHeadAttentionWrapper.multi_head_model(2)
models = MultiHeadAttentionWrapper.build_heads(heads)
MultiHeadAttentionWrapper.predict(models, batch)

3.6.2 Implementing multi-head attention with weight splits

The MultiHeadAttention module integrates the multi-head functionality within a single module. It splits the input into multiple heads by reshaping the projected query, key, and value tensors and then combines the results from these heads after computing attention.
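
A shape-only sketch of that split (an addition with illustrative numbers): a projected tensor of shape {batch, num_tokens, d_out} is reshaped into num_heads slices of size head_dim = d_out / num_heads, and the head axis is moved in front of the token axis, matching the reshape and transpose used inside the module below.

# Illustrative shapes only: batch of 2, 6 tokens, d_out = 4 split into 2 heads.
{b, num_tokens, d_out, num_heads} = {2, 6, 4, 2}
head_dim = div(d_out, num_heads)

projected = Nx.iota({b, num_tokens, d_out}, type: :f32)

split =
  projected
  |> Nx.reshape({b, num_tokens, num_heads, head_dim})
  |> Nx.transpose(axes: [0, 2, 1, 3])

Nx.shape(split)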

defmodule MultiHeadAttention do
  import Nx.Defn


  def model(args \\ []) do
    shape = Keyword.get(args, :shape, {nil, 6, 3})
    d_out = Keyword.get(args, :d_out, 4)
    d_in = Keyword.get(args, :d_in, 3)
    num_heads = Keyword.get(args, :num_heads, 2)
    dropout_rate = Keyword.get(args, :dropout_rate, 0.5)
    
    Axon.input("sequence", shape: shape)
    |> causal(d_out: d_out, d_in: d_in, num_heads: num_heads, name: "causal")
    |> Axon.dropout(rate: dropout_rate)
  end

  def build_model(model, args) do
    shape = Keyword.get(args, :shape, {1,6,3})
    model_state = Keyword.get(args, :model_state, %{})
    {init_fn, predict_fn} = Axon.build(model)
    template = Nx.template(shape, :f32)
    params = init_fn.(template, model_state)
    {params, init_fn, predict_fn}
  end

  def causal(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name, :d_in, :d_out, :num_heads])
    head_dim = div(opts[:d_out], opts[:num_heads])
    w_query = Axon.param("w_query", fn _ -> {opts[:d_in], opts[:d_out]} end)
    w_key = Axon.param("w_key", fn _ -> {opts[:d_in], opts[:d_out]} end)
    w_value = Axon.param("w_value", fn _ -> {opts[:d_in], opts[:d_out]} end)
    out_proj = Axon.param("out_proj", fn _ -> {opts[:d_out], opts[:d_out]} end)

    Axon.layer(
      &attention_impl/6,
      [input, w_query, w_key, w_value, out_proj],
      name: opts[:name],
      op_name: :causal_attention,
      head_dim: head_dim,
      num_heads: opts[:num_heads]
    )
  end

  defnp attention_impl(input, w_query, w_key, w_value, out_proj, opts \\ []) do
    {b, num_tokens, _d_in} = Nx.shape(input)
    keys = Nx.dot(input, w_key)
    queries = Nx.dot(input, w_query)
    values = Nx.dot(input, w_value)
    # Scale attention scores by the per-head key dimension (scaled dot-product attention)
    d_k = opts[:head_dim]

    keys_reshaped = 
      keys
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]}) 
      |> Nx.transpose(axes: [0, 2, 1, 3])
    
    queries_reshaped = 
      queries
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]}) 
      |> Nx.transpose(axes: [0, 2, 1, 3])
    
    values_reshaped = 
      values
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]}) 
      |> Nx.transpose(axes: [0, 2, 1, 3])

    attn_score =
      keys_reshaped
      |> Nx.transpose(axes: [0, 1, 3, 2])
      |> then(&Nx.dot(queries_reshaped, [3], [0, 1], &1, [2], [0, 1]))

    simple_mask =
      attn_score
      |> then(&Nx.broadcast(Nx.Constants.infinity(), &1))
      |> Nx.triu(k: 1)

    masked = Nx.multiply(simple_mask, -1) |> Nx.add(attn_score)

    attn_weights =
      masked
      |> Nx.divide(Nx.pow(d_k, 0.5))
      |> Axon.Activations.softmax(axis: -1)
    
    context_vec =
      attn_weights
      |> Nx.dot([3], [0, 1], values_reshaped, [2], [0, 1])
      |> Nx.transpose(axes: [0, 2, 1, 3])

    context_vec
    |> Nx.reshape({b, num_tokens, opts[:num_heads] * opts[:head_dim]})
    |> Nx.dot(out_proj)
  end
end
model = MultiHeadAttention.model()
{params, init_fn, predict_fn} = MultiHeadAttention.build_model(model, [])
result = predict_fn.(params, batch)
model = MultiHeadAttention.model(d_out: 2)
{params, init_fn, predict_fn} = MultiHeadAttention.build_model(model, [])
result = predict_fn.(params, batch)

Summary

  • Attention mechanisms transform input elements into enhanced context vector representations that incorporate information about all inputs.
  • A self-attention mechanism computes the context vector representation as a weighted sum over the inputs.
  • In a simplified attention mechanism, the attention weights are computed via dot products.
  • A dot product is a concise way of multiplying two vectors element-wise and then summing the products.
  • Matrix multiplications, while not strictly required, help us implement computations more efficiently and compactly by replacing nested for loops.
  • In self-attention mechanisms used in LLMs, also called scaled dot-product attention, we include trainable weight matrices to compute intermediate transformations of the inputs: queries, keys, and values.
  • When working with LLMs that read and generate text from left to right, we add a causal attention mask to prevent the LLM from accessing future tokens.
  • In addition to causal attention masks to zero out attention weights, we can add a dropout mask to reduce overfitting in LLMs.
  • The attention modules in transformer-based LLMs involve multiple instances of causal attention, which is called multi-head attention.
  • We can create a multi-head attention module by stacking multiple instances of causal attention modules.
  • A more efficient way of creating multi-head attention modules involves batched matrix multiplications.