Gated Delta Net Fwd H Kernel

This example implements the fwd_h kernel used in Gated DeltaNet: for every chunk of the sequence, it materializes the recurrent state h that enters that chunk.
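In symbols (notation reconstructed from the code below, not quoted from the Gated DeltaNet paper): write W_t, K_t, U_t for the chunk-t slices of w, k, u, let g hold the cumulative log-decay at each position of the chunk, and let g_last be its value at the chunk's final position. Each chunk first stores the incoming state H_t into h and then updates it:

V'_t = \exp(g_{\mathrm{last}} - g) \odot (U_t - W_t H_t)
H_{t+1} = \exp(g_{\mathrm{last}}) \, H_t + K_t^{\top} V'_t

with H_0 = 0, so h[:, t] always holds the state entering chunk t.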
Imports
from __future__ import annotations
import math
from typing import Callable
import torch
import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl
Helion Kernel Implementation
@helion.kernel()
def helion_gdn_fwd_h(
    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """
    Arguments:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, seqlen, nheads, dhead = k.shape
    dhead = hl.specialize(dhead)
    chunk_size = hl.specialize(chunk_size)
    dstate = u.shape[-1]
    acc_dtype = torch.float32
    dtype = k.dtype
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    h = torch.empty(
        batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device
    )
    block_v = hl.register_block_size(dstate)
    for i_b, i_h in hl.grid([batch, nheads]):
        for tile_v in hl.tile(dstate, block_size=block_v):
            # Recurrent state for this (batch, head, value-tile), accumulated in fp32.
            b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
            for t_i in hl.tile(seqlen, block_size=chunk_size):
                # Store the state *entering* this chunk before updating it.
                h[i_b, t_i.id, i_h, :, tile_v] = b_h.to(dtype)
                # v' = u - w @ h: subtract the part of u already explained by the state.
                b_w = w[i_b, t_i, i_h, :]
                c_h = b_h.to(dtype)
                b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
                p_v = u[i_b, t_i, i_h, tile_v].to(acc_dtype)
                b_v = p_v - b_v
                # Gate each position by exp(g_last - g_t), masking past-the-end positions.
                m_t = t_i.index < seqlen
                t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
                b_g_last = g[i_b, t_i_last, i_h].to(acc_dtype)
                b_g = g[i_b, t_i, i_h].to(acc_dtype)
                b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
                # Decay the carried state by exp(g_last), then accumulate k^T @ v'.
                b_g_last = torch.exp(b_g_last)
                b_h *= b_g_last
                b_v = b_v.to(dtype)
                p_k = k[i_b, t_i, i_h, :]
                b_h = hl.dot(p_k.T, b_v, acc=b_h)
    return h
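As a quick smoke test of the kernel on its own (a sketch, not part of the original example: the shapes and the explicit "cuda" device are assumptions, and seqlen is chosen as a multiple of chunk_size so the same inputs also work with the reference below):

batch, seqlen, nheads, dhead, dstate, chunk = 2, 512, 4, 64, 128, 256
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device="cuda")
w = torch.randn_like(k)
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device="cuda")
# Negative cumulative log-decay, monotonically decreasing along the sequence.
g = -torch.rand(batch, seqlen, nheads, dtype=torch.float32, device="cuda").cumsum(dim=1)
h = helion_gdn_fwd_h(k, w, u, g, chunk)
assert h.shape == (batch, seqlen // chunk, nheads, dhead, dstate)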
def helion_gdn_fwd_h_tb(
    tb_obj: object,
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor,
    chunk_size: int,
) -> Callable[[], torch.Tensor]:
    """
    Arguments:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size)
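This wrapper returns a zero-argument closure, the shape benchmark harnesses such as TritonBench expect so they can time repeated calls. It never touches tb_obj, so for ad-hoc timing one could invoke it directly (hypothetical usage, assuming tensors built as in the smoke test above):

fn = helion_gdn_fwd_h_tb(None, k, w, u, g, chunk)
h = fn()  # each call re-runs the kernel on the captured inputs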
Reference Function
def ref_gdn_fwd_h(
    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """
    Arguments:
        k: (batch, seqlen, nheads, dhead)
        w: (batch, seqlen, nheads, dhead)
        u: (batch, seqlen, nheads, expand_v*dhead)
        g: (batch, seqlen, nheads)
        chunk_size: int
    Return:
        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
    """
    batch, seqlen, nheads, dhead = k.shape
    expand_v = u.shape[-1] // dhead
    nchunks = (seqlen + chunk_size - 1) // chunk_size
    acc_dtype = torch.float32
    dtype = k.dtype
    h = torch.empty(
        batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device
    )
    b_h = torch.zeros(
        batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device
    )
    # Split the sequence into chunks (assumes seqlen is a multiple of chunk_size).
    k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
    w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
    u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead)
    g_c = g.reshape(batch, nchunks, chunk_size, nheads)
    for i_t in range(nchunks):
        # Store the state entering chunk i_t.
        h[:, i_t, :, :, :] = b_h.to(dtype)
        b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
        # Round-trip through the storage dtype to match the kernel's precision.
        c_h = b_h.to(dtype).to(acc_dtype)
        b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
        p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
        b_v = p_v - b_v
        last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
        m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
        b_g_last = g[:, last_idx, :].to(acc_dtype)
        b_g = g_c[:, i_t, :, :].to(acc_dtype)  # (batch, chunk, nheads)
        b_v *= torch.where(
            m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0
        ).unsqueeze(-1)
        b_g_last = torch.exp(b_g_last)
        b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
        b_v = b_v.to(dtype).to(acc_dtype)
        p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
        b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
    return h
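One cheap invariant follows directly from both implementations and makes a handy sanity check (a sketch reusing the smoke-test inputs above; the tolerances are guesses for bfloat16 accumulation, not values from the original example): h[:, t] is the state entering chunk t, so the first slot is always the zero initial state.

h_ref = ref_gdn_fwd_h(k, w, u, g, chunk)
assert torch.all(h_ref[:, 0] == 0)  # state entering the first chunk is the zero init
torch.testing.assert_close(helion_gdn_fwd_h(k, w, u, g, chunk), h_ref, atol=1e-2, rtol=1e-2)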
Testing Function
def test(
    batch: int,
    nheads: int,
    seqlen: int,
    chunk_size: int,
    dhead: int,
    dstate: int,
    dtype: torch.dtype = torch.float16,
) -> None:
    k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE)
    k = torch.nn.functional.rms_norm(k, [dhead])
    w = torch.randn(
        batch,
        seqlen // chunk_size,
        chunk_size,
        nheads,
        dhead,
        dtype=torch.float32,
        device=DEVICE,
    )
    # w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,))
    # Replace each chunk's w with its nearest semi-orthogonal factor (the polar
    # factor U @ Vh from the SVD) to keep the test inputs well conditioned.
    wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
    w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
    w = (
        w.permute(0, 1, 3, 2, 4)
        .reshape(batch, seqlen, nheads, dhead)
        .to(torch.bfloat16)
    )
    u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE)
    u = torch.nn.functional.rms_norm(u, [dstate])
    # g is a cumulative sum of non-positive log-decay terms, so it decreases
    # along the sequence and exp(g_last - g_t) stays in (0, 1].
    g = torch.cumsum(
        0.5
        * math.log(1 / dhead)
        * torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE),
        dim=1,
    )
    args = (k, w, u, g, chunk_size)
    run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
Main Function
def main() -> None:
    """
    Main entry point that runs the gated delta net fwd_h kernel test
    with specific parameters.
    """
    test(8, 80, 4096, 256, 64, 128)


if __name__ == "__main__":
    main()