TL;DR: I develop an IO-aware kernel for the RG-LRU in Triton. When training on sequences longer than 4K tokens, the kernel is faster than PyTorch’s Flash Attention 2 implementation.
Hawk is an RNN proposed by Google DeepMind in “Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models” [1]. At its core is the Real-Gated Linear Recurrent Unit (RG-LRU) layer, a linear and diagonal recurrence. Linear and diagonal means that each element of the hidden state depends only on its corresponding element from the previous timestep.
The equations for computing the recurrence are:
\[\begin{align*} r_t &= \sigma(W_a x_t + b_a) \\ i_t &= \sigma(W_x x_t + b_x) \\ a_t &= a^{8 r_t}, \quad a = e^{-\text{softplus}(C)} \\ h_t &= a_t \odot h_{t-1} + \sqrt{1 - a_t^2} \odot (i_t \odot x_t) \end{align*}\]
Where the input \(x\) is of shape [B, L, R] (so each \(x_t\) is [B, R]), \(h_t\) is of shape [B, R], and \(C\) is a trainable parameter of size R that sets the per-channel base \(a\). In PyTorch pseudocode, the forward pass is:
"""Example RG-LRU forward pass:
B: batch size
L: sequence length
D: model dimension
R: Recurrent dimension
"""
# forward pass
gate_x_BLR = sigmoid(x_BLR @ W_xRR)
gate_a_BLR = sigmoid(x_BLR @ W_aRR)
log_a_BLR = -8.0 * gate_a_BLR * softplus(C_R)
a_BLR = exp(log_a_BLR)
a_square_BLR = exp(2 * log_a_BLR)
gated_x_BLR = x_BLR * gate_x_BLR
normalized_x_BLR = gated_x_BLR * sqrt(1 - a_square_BLR)
# compute linear recurrence -> fused into single kernel in Pallas or Triton
h_t_BR = zeros(B,R)
for i in range(L):
    # load a_BLR[:, i] and normalized_x_BLR[:, i] to SRAM
    h_t_BR = a_BLR[:, i] * h_t_BR + normalized_x_BLR[:, i]
    # write h_t_BR to HBM
As the recurrence itself performs no matrix multiplications, only elementwise operations, the runtime of this kernel will be memory-bound on any GPU or TPU. The authors develop a custom Pallas kernel to compute this linear recurrence, which can be viewed in DeepMind’s “RecurrentGemma” GitHub repository here. On the systems side, an interesting takeaway from this paper is that a “naive” fused linear scan kernel actually outperforms the more parallelizable associative scan kernel. Others have noticed similar results for GPUs, so the original kernel I wrote was a simple linear scan of the form \(h_t = \alpha_t \odot h_{t-1} + \beta_t\), where \(\alpha_t\) and \(\beta_t\) are the respective recurrent parameters, pre-computed.
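To make the two scan styles concrete: a recurrence of the form \(h_t = \alpha_t h_{t-1} + \beta_t\) can be evaluated step by step, or by combining \((\alpha, \beta)\) pairs with an associative operator, which is what allows a parallel scan. The sketch below is plain Python for a single channel. The function names are illustrative and not taken from RecurrentGemma, and it only checks that both formulations agree; it says nothing about their relative speed.

```python
from itertools import accumulate

def linear_scan(alphas, betas, h0=0.0):
    """Sequential scan: h_t = alpha_t * h_{t-1} + beta_t."""
    h, out = h0, []
    for a, b in zip(alphas, betas):
        h = a * h + b
        out.append(h)
    return out

def combine(left, right):
    """Associative operator on (alpha, beta) pairs.

    Applying `left` and then `right` to some h gives
    a_r * (a_l * h + b_l) + b_r = (a_l * a_r) * h + (a_r * b_l + b_r).
    """
    a_l, b_l = left
    a_r, b_r = right
    return (a_l * a_r, a_r * b_l + b_r)

def associative_scan(alphas, betas):
    """Inclusive scan with `combine`; with h_0 = 0 the beta component is h_t.

    accumulate() walks left to right here, but since `combine` is associative
    a parallel implementation is free to group the pairs differently.
    """
    return [b for _, b in accumulate(zip(alphas, betas), combine)]

alphas = [0.9, 0.5, 0.8, 0.7]
betas = [1.0, 2.0, -1.0, 0.5]
seq = linear_scan(alphas, betas)
par = associative_scan(alphas, betas)
print(max(abs(s - p) for s, p in zip(seq, par)))
```

The associative form does extra multiplications per element and, on real hardware, extra passes over memory, which is one plausible reason a fused sequential scan can still win when the layer is memory-bound.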
Since I don’t have access to TPUs, I re-implemented Hawk and the linear scan kernel it uses in PyTorch and Triton [2]. You can view the initial linear scan kernel here.
One thing that might stand out is that before we even get to the actual linear recurrence, we are performing many elementwise operations to get the recurrent parameters ready. While it’s likely that XLA or torch.compile could fuse some of these operations together, we can do a lot better by fully rewriting the kernel to use two techniques: manual operator fusion and activation recomputation.
The first optimization we apply is to fuse all elementwise operations into the linear scan kernel itself. In pseudocode, our new RG-LRU computation would look as follows.
# forward pass
gate_x_no_act_BLR = x_BLR @ W_xRR
gate_a_no_act_BLR = x_BLR @ W_aRR
# compute linear recurrence -> fused into single kernel in Triton
h_t_BR = zeros(B,R)
for i in range(L):
    # load gate_x_no_act_BLR[:, i], gate_a_no_act_BLR[:, i], x_BLR[:, i], C_R to SRAM
    # Start of computation in SRAM
    gate_x_BR = sigmoid(gate_x_no_act_BLR[:, i])
    gate_a_BR = sigmoid(gate_a_no_act_BLR[:, i])
    log_a_BR = -8.0 * gate_a_BR * softplus(C_R)
    a_BR = exp(log_a_BR)
    a_square_BR = exp(2 * log_a_BR)
    gated_x_BR = x_BLR[:, i] * gate_x_BR
    normalized_x_BR = gated_x_BR * sqrt(1 - a_square_BR)
    h_t_BR = a_BR * h_t_BR + normalized_x_BR
    # End of computation in SRAM
    # write h_t_BR to HBM
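As a sanity check on the two pseudocode listings above, here is a minimal runnable sketch in plain Python (no GPU, no PyTorch) for a single sequence. It assumes the gate pre-activations, i.e. the outputs of the two matmuls, are already given as nested lists of shape [L][R]. The function names and the toy inputs are made up for illustration. The only point is that fusing the elementwise work into the scan loop does not change the result.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softplus(z):
    return math.log1p(math.exp(z))

def rglru_unfused(x, gx_pre, ga_pre, C):
    """Materialize a_t and the normalized input for every timestep, then scan."""
    L, R = len(x), len(C)
    a = [[math.exp(-8.0 * sigmoid(ga_pre[t][r]) * softplus(C[r]))
          for r in range(R)] for t in range(L)]
    nx = [[x[t][r] * sigmoid(gx_pre[t][r]) * math.sqrt(1.0 - a[t][r] ** 2)
           for r in range(R)] for t in range(L)]
    h, out = [0.0] * R, []
    for t in range(L):
        h = [a[t][r] * h[r] + nx[t][r] for r in range(R)]
        out.append(h)
    return out

def rglru_fused(x, gx_pre, ga_pre, C):
    """Do all elementwise work inside the scan loop; only h_t is stored."""
    L, R = len(x), len(C)
    sp_C = [softplus(c) for c in C]
    h, out = [0.0] * R, []
    for t in range(L):
        new_h = []
        for r in range(R):
            log_a = -8.0 * sigmoid(ga_pre[t][r]) * sp_C[r]
            a = math.exp(log_a)
            a_square = math.exp(2.0 * log_a)
            gated_x = x[t][r] * sigmoid(gx_pre[t][r])
            new_h.append(a * h[r] + gated_x * math.sqrt(1.0 - a_square))
        h = new_h
        out.append(h)
    return out

# Toy inputs with L = 3 and R = 2 (arbitrary values).
x = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
gx_pre = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
ga_pre = [[-1.0, 0.5], [0.0, 1.0], [2.0, -2.0]]
C = [0.5, -0.5]

h_unfused = rglru_unfused(x, gx_pre, ga_pre, C)
h_fused = rglru_fused(x, gx_pre, ga_pre, C)
print(max(abs(u - f) for hu, hf in zip(h_unfused, h_fused) for u, f in zip(hu, hf)))
```

In the unfused version the intermediate lists `a` and `nx` play the role of tensors written to and read back from HBM; the fused version never materializes them.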
Assuming half-precision weights and activations, the unfused RG-LRU layer reads 26BLR + 4RR + 2R bytes and writes 22BLR bytes for a total of 48BLR + 4RR + 2R bytes. The fused RG-LRU layer reads 4BLR + 4RR + 2R bytes and writes 2BLR bytes for a total of 6BLR + 4RR + 2R bytes. For all models considered, L will be greater than or equal to R and B > 1, so the BLR term dominates. While the memory traffic for both kernels is still \(\mathcal{O}(BLR)\), the fused kernel reduces the constant factor from \(48\) to \(6\), so it reads and writes roughly one eighth the bytes of the unfused kernel. As the runtime of this layer is memory-bandwidth-bound, we can expect large speedups. When implemented in Triton, the speedup over the initial kernel is large, and for contexts greater than 6K tokens we are even faster than PyTorch’s Flash Attention 2 (called via scaled_dot_product_attention):
All benchmarks were performed on a single 40GB A100 from Lambda Labs with PyTorch 2.5.1 and Triton 3.1.0.
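To put numbers on the byte counts above, the short sketch below simply evaluates the two totals stated in the text (48BLR + 4RR + 2R and 6BLR + 4RR + 2R) for a few sequence lengths. The helper name `rglru_bytes` is made up for illustration, B = 8 and R = 1024 are the benchmark settings used in this post, and the sequence lengths are arbitrary examples.

```python
def rglru_bytes(B, L, R):
    """Total bytes moved in the forward pass, using the half-precision counts from the text."""
    blr = B * L * R
    weights_and_c = 4 * R * R + 2 * R  # the 4RR + 2R term shared by both totals
    unfused = 48 * blr + weights_and_c
    fused = 6 * blr + weights_and_c
    return unfused, fused

B, R = 8, 1024
ratios = {}
for L in (1024, 4096, 16384):
    unfused, fused = rglru_bytes(B, L, R)
    ratios[L] = unfused / fused
    print(f"L={L:>6}: unfused={unfused / 2**20:9.1f} MiB  "
          f"fused={fused / 2**20:8.1f} MiB  ratio={ratios[L]:.2f}")
```

Because the 4RR + 2R term is shared by both totals, the ratio stays slightly below 8 and approaches it as L grows.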
While fusing essentially all the operations but the matmuls into the forward recurrence works well for inference (and really, only the first step of the inference, since this is… you know… an RNN), it is not possible to use this technique during training since we are not storing any activations needed to compute the gradients of the trainable parameters.
To use this fused kernel for training, we need to use a technique called activation recomputation, which works by “recomputing” certain intermediate activations needed for gradient computations in the backward pass instead of saving them in the forward pass. This is one of the main optimizations behind Flash Attention [3]. Combining operator fusion with activation recomputation is not a new idea: besides its use in the Flash Attention kernels, He et al. [4] show that in some more general settings, the combination of fusion and recomputation can actually speed up the forward + backward runtime of models due to the reduced memory traffic, and that selective fusion and recomputation can be determined at a graph’s compile time.
For the RG-LRU, this technique is simple to apply manually. Instead of viewing the linear recurrence as one big computation graph, I found it mentally easier to split the gradient computation into two phases. The first phase computes the linear recurrence parameters:
\[\begin{align*} \alpha_t &= f(r_t)\\ \beta_t &= g(r_t,i_t,x_t) \end{align*}\]
Where \(r_t\) and \(i_t\) here denote the gate pre-activations (gate_a_no_act and gate_x_no_act in the pseudocode above), and \(f\) and \(g\) are computed as:
\[\begin{align*} f(r_t) &= \exp\left(-8.0 \cdot \sigma(r_t) \odot \text{softplus}(C)\right)\\ g(r_t,i_t,x_t) &= \sqrt{1-f(r_t)^2}\odot x_t \odot \sigma(i_t) \end{align*}\]
And the second phase simply computes a standard diagonal linear recurrence:
\[\begin{align*} h_t = \alpha_t \odot h_{t-1} + \beta_t \end{align*}\]
Viewing the kernel this way allows us to first manually compute the gradients of the loss with respect to \(\alpha_t\) and \(\beta_t\) [5] as:
\[\begin{align*} \nabla_{h_T} L &= \frac{\partial L}{\partial h_T} \\ \nabla_{h_t} L &= \frac{\partial h_{t+1}}{\partial h_t} \odot \nabla_{h_{t+1}} L + \frac{\partial L}{\partial h_t} \\ &= \alpha_{t+1} \odot \nabla_{h_{t+1}} L + \frac{\partial L}{\partial h_t} \\ \nabla_{\alpha_t} L &= \frac{\partial h_t}{\partial \alpha_t} \odot \nabla_{h_t} L = h_{t-1} \odot \nabla_{h_t} L \\ \nabla_{\beta_t} L &= \nabla_{h_t} L \\ \nabla_{h_0} L &= \frac{\partial h_1}{\partial h_0} \odot \nabla_{h_1} L = \alpha_1 \odot \nabla_{h_1} L \end{align*}\]
From here, the gradients with respect to the quantities we actually differentiate through, \(x_t\), \(r_t\), \(i_t\), and \(C\), are simple elementwise computations and are left as an exercise 😉.
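For readers who want to check their answer to that exercise, here is one possible way to work it, as a plain-Python sketch for a single channel (the recurrence is diagonal, so channels are independent). It assumes \(\alpha_t = \exp(-8\,\sigma(r_t)\,\text{softplus}(C))\) as in the pseudocode, treats \(r_t\) and \(i_t\) as gate pre-activations, and uses a toy loss \(L = \sum_t w_t h_t\) so that \(\partial L / \partial h_t = w_t\). All names and values are illustrative. This is not the Triton kernel, only a reference that is checked against finite differences.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softplus(z):
    return math.log1p(math.exp(z))

def params(x_t, r_t, i_t, C):
    """Phase 1: alpha_t = f(r_t), beta_t = g(r_t, i_t, x_t) for one channel."""
    alpha = math.exp(-8.0 * sigmoid(r_t) * softplus(C))
    beta = math.sqrt(1.0 - alpha * alpha) * sigmoid(i_t) * x_t
    return alpha, beta

def forward(x, r, i, C):
    """Phase 2: h_t = alpha_t * h_{t-1} + beta_t with h_0 = 0; only h is kept."""
    h, hs = 0.0, []
    for x_t, r_t, i_t in zip(x, r, i):
        alpha, beta = params(x_t, r_t, i_t, C)
        h = alpha * h + beta
        hs.append(h)
    return hs

def backward(x, r, i, C, hs, dL_dh):
    """Reverse scan; alpha_t is recomputed from the inputs instead of being saved."""
    T = len(x)
    dx, dr, di, dC = [0.0] * T, [0.0] * T, [0.0] * T, 0.0
    carry = 0.0  # alpha_{t+1} * grad_h_{t+1}; zero past the last step
    for t in reversed(range(T)):
        alpha, _ = params(x[t], r[t], i[t], C)
        s_r, s_i = sigmoid(r[t]), sigmoid(i[t])
        m = math.sqrt(1.0 - alpha * alpha)
        grad_h = dL_dh[t] + carry
        h_prev = hs[t - 1] if t > 0 else 0.0
        grad_beta = grad_h
        # alpha_t also enters beta_t through sqrt(1 - alpha_t^2)
        grad_alpha = h_prev * grad_h + grad_beta * (-alpha / m) * s_i * x[t]
        dx[t] = grad_beta * m * s_i
        di[t] = grad_beta * m * x[t] * s_i * (1.0 - s_i)
        dr[t] = grad_alpha * alpha * (-8.0 * softplus(C)) * s_r * (1.0 - s_r)
        dC += grad_alpha * alpha * (-8.0 * s_r) * sigmoid(C)
        carry = alpha * grad_h
    return dx, dr, di, dC

x = [0.5, -1.0, 1.5, 0.2]
r = [-1.0, 0.5, 0.0, 1.0]
i = [0.1, -0.2, 0.3, 0.4]
C = -2.0
w = [1.0, -0.5, 2.0, 0.3]  # dL/dh_t for the toy loss L = sum_t w_t * h_t

def loss(x, r, i, C):
    return sum(w_t * h_t for w_t, h_t in zip(w, forward(x, r, i, C)))

hs = forward(x, r, i, C)
dx, dr, di, dC = backward(x, r, i, C, hs, w)

eps = 1e-6
fd_dC = (loss(x, r, i, C + eps) - loss(x, r, i, C - eps)) / (2 * eps)
print("dC analytic vs finite difference:", dC, fd_dC)
```

The backward pass only needs the inputs and the stored hidden states; everything else is recomputed inside the loop. Note that the \(-\alpha_t / \sqrt{1 - \alpha_t^2}\) factor grows large as \(\alpha_t\) approaches 1, so a half-precision kernel would likely need some care there.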
The forward and backward passes for this kernel were implemented in Triton. Using the same benchmark setup as above (R = 1024, B = 8), our fully-fused kernel with activation recomputation is fast:
All benchmarks were performed on a single 40GB A100 from Lambda Labs with PyTorch 2.5.1 and Triton 3.1.0.
Similar to the forward pass kernel, for contexts greater than 4K tokens, we are once again faster than Flash Attention 2.
Using this kernel I trained a 12L/768D Hawk model on 100B tokens of OpenWebText (GPT2-small config):
If you’re interested in reading through the full Triton kernel, it’s here, and if you’re interested in my implementation of Hawk, it’s here!
[1] De, Soham, et al. “Griffin: Mixing gated linear recurrences with local attention for efficient language models.” arXiv preprint arXiv:2402.19427 (2024).
[2] Tillet, Philippe, Hsiang-Tsung Kung, and David Cox. “Triton: an intermediate language and compiler for tiled neural network computations.” Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages. 2019.
[3] Dao, Tri, et al. “Flashattention: Fast and memory-efficient exact attention with io-awareness.” Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
[4] He, Horace, and Shangdi Yu. “Transcending runtime-memory tradeoffs in checkpointing by being fusion aware.” Proceedings of Machine Learning and Systems 5 (2023): 414-427.
[5] Martin, Eric, and Chris Cundy. “Parallelizing linear recurrent neural nets over sequence length.” arXiv preprint arXiv:1709.04057 (2017).