Accelerating the Llama: A hybrid approach to attention
In this article, we will see how to replace the softmax attention in Llama-3.2-1B with a hybrid attention that combines sliding-window softmax attention and linear attention. This implementation will help us better understand the growing interest in linear attention research, while exploring its limitations and possible future directions.
This article is a recreation of the LoLCATs paper using Llama 3.2 1B, where we will replace 50% of the attention layers in the pre-trained Llama model. The article consists of four main parts:
- Hybrid Attention Block
- Attention Transfer
- LoRA Fine-tuning
- Testing
The main goal of this article is to show that we can replace the softmax attention in an already-trained model and thereby speed up inference without losing too much accuracy. If we can achieve this, we can significantly reduce the cost of using LLMs!
Let's see what the Llama-3.2-1B model looks like:
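If you want to follow along, a minimal sketch for loading the model and printing its structure might look like this (assuming you have access to the meta-llama/Llama-3.2-1B checkpoint on Hugging Face):
# Minimal sketch: load Llama-3.2-1B and print its module structure
# (assumes access to the meta-llama/Llama-3.2-1B checkpoint on Hugging Face)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

print(model)  # shows 16 LlamaDecoderLayer blocks, each with a self_attn module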
Since we have 16 identical decoder blocks, our focus will be on the self_attn part, so the goal of this section is to understand how the LlamaSdpaAttention block works. Let's see what LlamaSdpaAttention looks like:
class LlamaSdpaAttention(LlamaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """
You can check what this class looks like using the following code:
import inspect

attention_layer = model.model.layers[0].self_attn
print(inspect.getsource(attention_layer.__class__))
Let's go through the main parts of this code, understand what each part does, and see where we need to make a change.
Let's take a dummy input of shape [2, 4, 2048] → [batch_size, seq_len, embedding_dim]. Llama uses multi-headed attention with 32 query heads.
Block 1:
After the projections, query_states is a tensor of shape [2, 4, 2048], key_states is [2, 4, 512], and value_states is [2, 4, 512].
After the view and transpose operations: query_states → [2, 32, 4, 64], key_states → [2, 8, 4, 64], value_states → [2, 8, 4, 64].
Here 64 is the per-head embedding dimension. Key and value have only 8 heads because Llama uses grouped key-value heads: the 32 query heads are split into 8 groups of 4, and each group shares the same key_states and value_states.
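To make these shapes concrete, here is a small shape-only sketch with random tensors and standalone nn.Linear layers standing in for the real projection weights:
import torch
import torch.nn as nn

bsz, seq_len, hidden = 2, 4, 2048
num_heads, num_kv_heads, head_dim = 32, 8, 64

hidden_states = torch.randn(bsz, seq_len, hidden)

# Projections as in Llama: queries keep 32 heads, keys/values only 8
q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False)     # 2048 -> 2048
k_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)  # 2048 -> 512
v_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)  # 2048 -> 512

query_states = q_proj(hidden_states).view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
key_states = k_proj(hidden_states).view(bsz, seq_len, num_kv_heads, head_dim).transpose(1, 2)
value_states = v_proj(hidden_states).view(bsz, seq_len, num_kv_heads, head_dim).transpose(1, 2)

print(query_states.shape)  # torch.Size([2, 32, 4, 64])
print(key_states.shape)    # torch.Size([2, 8, 4, 64])
print(value_states.shape)  # torch.Size([2, 8, 4, 64])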
Block 2:
In this block we just apply positional encoding, specifically Rotary Position Embeddings (RoPE) in Llama's case. I won't go into the details of why this is needed, but you can read the following article to get a better idea:
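As a rough illustration of what this block computes, here is a simplified RoPE sketch (not the exact Hugging Face implementation, just the core rotate-and-mix step):
import torch

def rotate_half(x):
    # split the head dimension into two halves and swap them with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q, k, cos, sin):
    # cos/sin are position-dependent and broadcast over batch and heads
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed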
Block 3:
Here we use the repeat_kv function, which simply repeats each key/value head across its group of 4 query heads, and we use past_key_value so that cached key/value states can be reused during generation instead of being recomputed, for efficiency.
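Conceptually, repeat_kv just expands each of the 8 key/value heads 4 times so they line up with the 32 query heads. A minimal sketch of that behaviour (the function name here is mine, but it mirrors what the transformers helper does):
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_kv_heads * n_rep, seq_len, head_dim]
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

key_states = torch.randn(2, 8, 4, 64)
print(repeat_kv_sketch(key_states, 4).shape)  # torch.Size([2, 32, 4, 64])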
Block 4:
Block 4 carries out two main preparatory steps: applying the causal mask so that tokens only attend to previous positions, and making the tensors contiguous in memory for better GPU performance.
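The causal mask itself is conceptually just an upper-triangular matrix of -inf values added to the attention scores; a toy version of that idea (the actual Llama code slices a precomputed mask to the right length instead):
import torch

seq_len = 4
# positions j > i get -inf so that softmax assigns them zero weight
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
print(causal_mask)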
Block 5:
This is where we use softmax attention – the part we will change in our implementation.
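In the original block this boils down to a single call to torch.nn.functional.scaled_dot_product_attention, roughly like the sketch below (simplified: the real code passes the prepared causal mask and dropout instead of is_causal=True):
import torch
import torch.nn.functional as F

query_states = torch.randn(2, 32, 4, 64)
key_states = torch.randn(2, 32, 4, 64)   # already repeated to 32 heads
value_states = torch.randn(2, 32, 4, 64)

# softmax(QK^T / sqrt(d)) V with causal masking, fused into one kernel
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
print(attn_output.shape)  # torch.Size([2, 32, 4, 64])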
Block 6:
The output of attention will be a tensor of shape [2, 32, 4, 64]. We reshape it back to [2, 4, 2048] and apply the final output projection.
And that's one full pass through Llama's attention!
So now let's look at our HybridAttention block:
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache
# assumes a transformers version that still exposes LlamaSdpaAttention (e.g. 4.4x)
from transformers.models.llama.modeling_llama import LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv

class HybridAttention(LlamaSdpaAttention):
    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)
        self.window_size = 64
        # Initialize learnable factors: one factor pair per attention head
        num_heads = config.num_attention_heads
        self.window_factors = torch.nn.Parameter(torch.ones(1, num_heads, 1, 1) * 0.5)
        self.linear_factors = torch.nn.Parameter(torch.ones(1, num_heads, 1, 1) * 0.5)
        self.factor_activation = torch.nn.Sigmoid()
    def sliding_window_attention(self, query_states, key_states, value_states, window_size, window_factor):
        """Compute sliding window attention"""
        batch_size, num_heads, seq_len, head_dim = query_states.shape

        key_windows = F.pad(key_states, (0, 0, window_size - 1, 0), value=0)
        key_windows = key_windows.unfold(2, window_size, 1)

        value_windows = F.pad(value_states, (0, 0, window_size - 1, 0), value=0)
        value_windows = value_windows.unfold(2, window_size, 1)

        attn_weights = torch.einsum('bhld,bhldw->bhlw', query_states, key_windows) * (head_dim ** -0.5)
        attn_weights = torch.where(attn_weights == 0,
                                   torch.tensor(-float('inf'), device=attn_weights.device),
                                   attn_weights)

        # Apply learnable window factor (with sigmoid to ensure positivity)
        attn_weights = self.factor_activation(window_factor) * F.softmax(attn_weights, dim=-1)

        attn_output = torch.einsum('bhlw,bhldw->bhld', attn_weights, value_windows)
        sum_weights = attn_weights.sum(dim=-1, keepdim=True)

        return attn_output, sum_weights
    def linear_attention(self, query_states, key_states, value_states, window_size, linear_factor):
        """Compute linear attention with cumsum"""
        def feature_map(x):
            return F.elu(x) + 1

        query_prime = feature_map(query_states)
        key_prime = feature_map(key_states)

        key_prime = F.pad(key_prime, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]
        value_padded = F.pad(value_states, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]

        # Compute KV
        kv = torch.einsum('bhlf,bhld->bhlfd', key_prime, value_padded)
        # Apply learnable linear factor (with sigmoid to ensure positivity)
        qkv = self.factor_activation(linear_factor) * torch.einsum('bhlf,bhlfd->bhld',
                                                                   query_prime,
                                                                   kv.cumsum(dim=2))

        sum_k = key_prime.cumsum(dim=2)
        sum_qk = self.factor_activation(linear_factor) * torch.einsum('bhld,bhld->bhl',
                                                                      query_prime,
                                                                      sum_k)[..., None]
        sum_qk = torch.where(sum_qk == 0, torch.tensor(1e-12, device=sum_qk.device), sum_qk)

        return qkv, sum_qk
    def hybrid_attention(self, query_states, key_states, value_states):
        """Combine sliding window and linear attention with learnable factors"""
        qkv_window, sum_window = self.sliding_window_attention(
            query_states, key_states, value_states,
            self.window_size, self.window_factors
        )
        qkv_linear, sum_linear = self.linear_attention(
            query_states, key_states, value_states,
            self.window_size, self.linear_factors
        )

        output = (qkv_window + qkv_linear) / (sum_window + sum_linear)
        return output
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = self.hybrid_attention(
            query_states,
            key_states,
            value_states
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
We only made one change to forward(): we replaced the 5th block (the softmax attention call) with the following:
attn_output = self.hybrid_attention(
    query_states,
    key_states,
    value_states
)
Basically, we split the attention computation into two parts: a sliding-window attention block and a linear attention block.
Sliding window attention:
def sliding_window_attention(self, query_states, key_states, value_states, window_size, window_factor):
    """Compute sliding window attention"""
    batch_size, num_heads, seq_len, head_dim = query_states.shape

    key_windows = F.pad(key_states, (0, 0, window_size - 1, 0), value=0)
    key_windows = key_windows.unfold(2, window_size, 1)

    value_windows = F.pad(value_states, (0, 0, window_size - 1, 0), value=0)
    value_windows = value_windows.unfold(2, window_size, 1)

    attn_weights = torch.einsum('bhld,bhldw->bhlw', query_states, key_windows) * (head_dim ** -0.5)
    attn_weights = torch.where(attn_weights == 0,
                               torch.tensor(-float('inf'), device=attn_weights.device),
                               attn_weights)

    # Apply learnable window factor (with sigmoid to ensure positivity)
    attn_weights = self.factor_activation(window_factor) * F.softmax(attn_weights, dim=-1)

    attn_output = torch.einsum('bhlw,bhldw->bhld', attn_weights, value_windows)
    sum_weights = attn_weights.sum(dim=-1, keepdim=True)

    return attn_output, sum_weights
For a deeper understanding of sliding-window attention concepts, I recommend you refer to this paper:
The idea here is that instead of computing attention over all key-value pairs at once (where each token looks at every other token), we restrict each token to a window of size w and compute attention only inside that window. With the code above, the time complexity drops from O(n²) to O(n·w), since each token only needs to attend to w tokens instead of all n tokens. It can be made even better using concepts like attention sinks, or by computing the window only over the last w tokens, which I may add in future updates.
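To see what the pad + unfold combination is doing, here is a tiny standalone example on a dummy key tensor with a window size of 3:
import torch
import torch.nn.functional as F

window_size = 3
key_states = torch.randn(1, 1, 5, 2)  # [batch, heads, seq_len=5, head_dim=2]

# Left-pad along the sequence dimension so every position has a full window of history
key_windows = F.pad(key_states, (0, 0, window_size - 1, 0), value=0)
# Unfold the sequence dimension into overlapping windows of size `window_size`
key_windows = key_windows.unfold(2, window_size, 1)

print(key_windows.shape)  # torch.Size([1, 1, 5, 2, 3]) -> one window of 3 keys per query position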
Linear attention:
def linear_attention(self, query_states, key_states, value_states, window_size, linear_factor):
    """Compute linear attention with cumsum"""
    def feature_map(x):
        return F.elu(x) + 1

    query_prime = feature_map(query_states)
    key_prime = feature_map(key_states)

    key_prime = F.pad(key_prime, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]
    value_padded = F.pad(value_states, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]

    # Compute KV
    kv = torch.einsum('bhlf,bhld->bhlfd', key_prime, value_padded)
    # Apply learnable linear factor (with sigmoid to ensure positivity)
    qkv = self.factor_activation(linear_factor) * torch.einsum('bhlf,bhlfd->bhld',
                                                               query_prime,
                                                               kv.cumsum(dim=2))

    sum_k = key_prime.cumsum(dim=2)
    sum_qk = self.factor_activation(linear_factor) * torch.einsum('bhld,bhld->bhl',
                                                                  query_prime,
                                                                  sum_k)[..., None]
    sum_qk = torch.where(sum_qk == 0, torch.tensor(1e-12, device=sum_qk.device), sum_qk)

    return qkv, sum_qk
For the linear attention, I'm using a very simple feature map, elu(x) + 1, but the main part to notice is the padding that happens first. The idea is that linear attention only needs to cover the first [sequence_length - window_size] tokens for each query, since the sliding window already handles the most recent context.
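As a quick sanity check of the cumsum formulation (ignoring the window shift for a moment), causal linear attention computed with cumulative sums matches the explicit per-position sum. A small sketch:
import torch
import torch.nn.functional as F

def feature_map(x):
    return F.elu(x) + 1

q = feature_map(torch.randn(1, 1, 6, 4))  # [batch, heads, seq_len, feat_dim]
k = feature_map(torch.randn(1, 1, 6, 4))
v = torch.randn(1, 1, 6, 4)

# cumsum formulation: running sum of outer products k_j v_j^T
kv = torch.einsum('bhlf,bhld->bhlfd', k, v).cumsum(dim=2)
num = torch.einsum('bhlf,bhlfd->bhld', q, kv)
den = torch.einsum('bhlf,bhlf->bhl', q, k.cumsum(dim=2))[..., None]
out_cumsum = num / den

# explicit causal formulation: position i attends to positions j <= i
scores = torch.einsum('bhif,bhjf->bhij', q, k)
scores = scores * torch.tril(torch.ones(6, 6, dtype=torch.bool))
out_explicit = (scores @ v) / scores.sum(dim=-1, keepdim=True)

print(torch.allclose(out_cumsum, out_explicit, atol=1e-4))  # True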
The combination of these two types of attention becomes our new hybrid attention, with window_factors and linear_factors acting as learnable parameters that control how much each type of attention contributes to the final output.
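Finally, to actually use the block, one way to replace 50% of the attention layers is to swap self_attn in every other decoder layer while reusing the pretrained projection weights. This is only a sketch of that wiring, assuming the model and the HybridAttention class defined above:
# Sketch: swap HybridAttention into every other decoder layer (50% of the attention blocks)
for layer_idx, layer in enumerate(model.model.layers):
    if layer_idx % 2 == 0:
        hybrid = HybridAttention(model.config, layer_idx=layer_idx)
        # reuse the pretrained q/k/v/o projection weights; strict=False skips the new factor parameters
        hybrid.load_state_dict(layer.self_attn.state_dict(), strict=False)
        layer.self_attn = hybrid.to(model.device, model.dtype)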