Fused Linear JSD Example#
This example demonstrates how to implement a memory-efficient fused linear + JSD kernel using Helion. The fusion avoids materializing the full [batch, vocab] logits tensor by:

1. Chunking the batch dimension
2. Computing logits per-chunk via matmul
3. Computing JSD loss and grad_logits per-chunk
4. Converting grad_logits back to grad_input (batch x hidden) via matmul
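As rough arithmetic for the problem sizes used in main() below (a sketch assuming fp32 logits; the chunk size of 32 comes from the heuristic in the autograd function):

batch, hidden, vocab = 1024, 4096, 128256
full_mib = 2 * batch * vocab * 4 / 2**20   # student + teacher logits: ~1002 MiB
chunk = 32                                 # picked by the chunking heuristic below
chunk_mib = 2 * chunk * vocab * 4 / 2**20  # resident per chunk: ~31 MiB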
Imports#
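The code below assumes the following imports; the path of the test helpers (helion._testing providing DEVICE and run_example) follows the convention of other Helion examples and is an assumption here.

from __future__ import annotations

from typing import Any, Callable

import torch
from torch import Tensor

import helion
import helion.language as hl
from helion._testing import DEVICE, run_example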
Helion JSD Kernel#
This kernel computes JSD loss and gradient w.r.t. logits for a single chunk. The gradient is computed during the forward pass to avoid storing full logits.
@helion.kernel()
def jsd_kernel(
beta: float,
ignore_index: int,
temperature: float,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute JSD loss and gradient w.r.t. student logits for a chunk.
Args:
beta: Interpolation coefficient for JSD (0 = student only, 1 = teacher only)
ignore_index: Label index to ignore (unused in label-free version)
temperature: Temperature for softmax
student_logits: Student model logits [chunk_size, vocab_size]
teacher_logits: Teacher model logits [chunk_size, vocab_size]
Returns:
loss: Per-sample JSD loss [chunk_size]
grad_student_logits: Gradient w.r.t. student logits [chunk_size, vocab_size]
"""
chunk_size, vocab_size = student_logits.size()
loss = student_logits.new_empty(chunk_size, dtype=torch.float)
grad_student_logits = torch.empty_like(student_logits, dtype=torch.float)
for tile in hl.tile(chunk_size):
# Scale by temperature
student_scaled = student_logits[tile, :] / temperature
teacher_scaled = teacher_logits[tile, :] / temperature
# Compute softmax probabilities
student_prob = torch.softmax(student_scaled.to(torch.float), dim=-1)
teacher_prob = torch.softmax(teacher_scaled.to(torch.float), dim=-1)
# Compute log probabilities for KL divergence
student_log_prob = torch.log_softmax(student_scaled.to(torch.float), dim=-1)
teacher_log_prob = torch.log_softmax(teacher_scaled.to(torch.float), dim=-1)
# Compute mixture distribution: m = (1-beta)*student + beta*teacher
m = (1 - beta) * student_prob + beta * teacher_prob
log_m = torch.log(m)
# JSD = (1-beta) * KL(student || m) + beta * KL(teacher || m)
# KL(p || q) = sum(p * (log(p) - log(q)))
student_kl = (student_prob * (student_log_prob - log_m)).sum(dim=-1)
teacher_kl = (teacher_prob * (teacher_log_prob - log_m)).sum(dim=-1)
jsd_loss = (1 - beta) * student_kl + beta * teacher_kl
loss[tile] = jsd_loss
        # Gradient of JSD w.r.t. student logits, by the chain rule through softmax.
        # With p = student_prob and m the mixture, d(JSD)/d(log p) = (1-beta) * p * (log(p) - log(m)),
        # and composing with the log-softmax Jacobian (scaled by 1/T) gives:
        # d(JSD)/d(student_logits) = (1-beta)/T * p * (log(p) - log(m) - KL(p || m))
        grad_tile = ((1 - beta) / temperature) * student_prob * (
            student_log_prob - log_m - student_kl[:, None]
        )
        grad_student_logits[tile, :] = grad_tile
return loss, grad_student_logits
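A hypothetical standalone call, showing the shapes involved (arbitrary sizes):

s = torch.randn(32, 1024, device=DEVICE, dtype=torch.float)
t = torch.randn(32, 1024, device=DEVICE, dtype=torch.float)
loss, grad = jsd_kernel(0.5, -100, 1.0, s, t)  # loss: [32], grad: [32, 1024]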
Backwards-Compatible Kernel for Unit Tests#
This kernel computes JSD loss from logits (without gradients). For memory-efficient fused linear + JSD, use fused_linear_jsd_fwd instead.
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def fused_linear_jsd_kernel(
beta: float,
ignore_index: int,
temperature: float,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
) -> torch.Tensor:
"""
Compute JSD loss from pre-computed logits.
This kernel is kept for unit testing the JSD computation in isolation.
For memory-efficient fused linear + JSD, use fused_linear_jsd_fwd instead.
"""
batch_size, _vocab_size = student_logits.size()
loss = student_logits.new_empty(batch_size, dtype=torch.float)
for tile in hl.tile(batch_size):
# Scale by temperature
student_scaled = student_logits[tile, :] / temperature
teacher_scaled = teacher_logits[tile, :] / temperature
# Compute softmax probabilities
student_prob = torch.softmax(student_scaled.to(torch.float), dim=-1)
teacher_prob = torch.softmax(teacher_scaled.to(torch.float), dim=-1)
# Compute log probabilities for KL divergence
student_log_prob = torch.log_softmax(student_scaled.to(torch.float), dim=-1)
teacher_log_prob = torch.log_softmax(teacher_scaled.to(torch.float), dim=-1)
# Compute mixture distribution: m = (1-beta)*student + beta*teacher
m = (1 - beta) * student_prob + beta * teacher_prob
log_m = torch.log(m)
# JSD = (1-beta) * KL(student || m) + beta * KL(teacher || m)
student_kl = (student_prob * (student_log_prob - log_m)).sum(dim=-1)
teacher_kl = (teacher_prob * (teacher_log_prob - log_m)).sum(dim=-1)
jsd_loss = (1 - beta) * student_kl + beta * teacher_kl
loss[tile] = jsd_loss
return (loss / batch_size).sum()
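The JSD math in these kernels can be cross-checked in eager PyTorch. A minimal sketch for beta = 0.5 and temperature = 1, using F.kl_div (whose (input, target) convention computes target * (log(target) - input) elementwise):

import torch.nn.functional as F

p = torch.softmax(torch.randn(4, 16), dim=-1)  # "student" probabilities
q = torch.softmax(torch.randn(4, 16), dim=-1)  # "teacher" probabilities
m = 0.5 * p + 0.5 * q
kl_pm = F.kl_div(m.log(), p, reduction="none").sum(-1)  # KL(p || m)
kl_qm = F.kl_div(m.log(), q, reduction="none").sum(-1)  # KL(q || m)
jsd = 0.5 * kl_pm + 0.5 * kl_qm  # matches the kernels' per-sample loss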
Autograd Function#
This implements the memory-efficient forward/backward using chunking.
class FusedLinearJSDFunction(torch.autograd.Function):
"""
Memory-efficient fused linear + JSD using chunking.
The forward pass computes gradients w.r.t. inputs, avoiding the need to
store the full [batch, vocab] logits tensor for backward.
"""
@staticmethod
def forward( # pyrefly: ignore [bad-override]
ctx: Any, # noqa: ANN401
student_input: Tensor,
student_weight: Tensor,
teacher_input: Tensor,
teacher_weight: Tensor,
beta: float,
ignore_index: int,
temperature: float,
) -> Tensor:
"""
Forward pass with chunking to reduce memory usage.
Instead of materializing full [batch, vocab] logits, we:
1. Process in chunks along the batch dimension
2. Compute logits per-chunk: chunk_logits = chunk_input @ weight.T
3. Compute JSD loss and grad_logits per-chunk
4. Convert grad_logits to grad_input: grad_input = grad_logits @ weight
This way we only ever have [chunk_size, vocab] in memory, not [batch, vocab].
"""
batch_size, hidden_dim = student_input.shape
vocab_size = student_weight.shape[0]
# Calculate chunk size to balance memory vs overhead
# Larger chunks = more memory but fewer kernel launches
# Use heuristic similar to Liger: chunk based on vocab/hidden ratio
inc_factor = max(1, (vocab_size + hidden_dim - 1) // hidden_dim)
chunk_size = 1
while chunk_size * 2 <= batch_size // inc_factor:
chunk_size *= 2
chunk_size = max(1, min(chunk_size, batch_size))
num_chunks = (batch_size + chunk_size - 1) // chunk_size
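        # Worked example for the sizes in main(): batch=1024, hidden=4096,
        # vocab=128256 gives inc_factor = ceil(128256 / 4096) = 32, so
        # chunk_size grows to 32 and num_chunks = 1024 / 32 = 32; each chunk
        # materializes only a [32, 128256] pair of logit tensors.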
# Allocate outputs
total_loss = torch.tensor(0.0, device=student_input.device, dtype=torch.float)
grad_student_input = torch.zeros_like(student_input, dtype=torch.float)
grad_student_weight = torch.zeros_like(student_weight, dtype=torch.float)
for chunk_id in range(num_chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, batch_size)
# Get chunk of inputs
student_input_chunk = student_input[start_idx:end_idx]
teacher_input_chunk = teacher_input[start_idx:end_idx]
# Compute logits for this chunk only: [chunk, vocab]
student_logits_chunk = student_input_chunk @ student_weight.T
teacher_logits_chunk = teacher_input_chunk @ teacher_weight.T
# Compute JSD loss and gradient w.r.t. logits
loss_chunk, grad_logits_chunk = jsd_kernel(
beta,
ignore_index,
temperature,
student_logits_chunk,
teacher_logits_chunk,
)
# Accumulate loss
total_loss = total_loss + loss_chunk.sum()
# Convert grad_logits [chunk, vocab] to grad_input [chunk, hidden]
            # grad_input = grad_logits @ weight (weight cast to float to match
            # the float32 gradients produced by the kernel)
            grad_student_input[start_idx:end_idx] = (
                grad_logits_chunk @ student_weight.float()
            )
# Accumulate weight gradient: grad_weight += grad_logits.T @ input
grad_student_weight = (
grad_student_weight + grad_logits_chunk.T @ student_input_chunk.float()
)
        # Normalize by batch size so the loss and the cached gradients both
        # correspond to a mean reduction over the batch
        total_loss = total_loss / batch_size
        grad_student_input = grad_student_input / batch_size
        grad_student_weight = grad_student_weight / batch_size
# Save for backward
ctx.save_for_backward(grad_student_input, grad_student_weight)
return total_loss
@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Any, # noqa: ANN401
) -> tuple[Tensor, Tensor, None, None, None, None, None]:
"""
Backward pass: rescale the precomputed gradients by grad_output.
"""
(grad_output,) = grad_outputs
grad_student_input, grad_student_weight = ctx.saved_tensors
# Scale gradients by upstream gradient
grad_student_input = grad_output * grad_student_input
grad_student_weight = grad_output * grad_student_weight
# Return gradients for: student_input, student_weight, teacher_input,
# teacher_weight, beta, ignore_index, temperature
return grad_student_input, grad_student_weight, None, None, None, None, None
Forward Function#
def fused_linear_jsd_fwd(
beta: float,
ignore_index: int,
temperature: float,
student_weight: torch.Tensor,
teacher_weight: torch.Tensor,
student_input: torch.Tensor,
teacher_input: torch.Tensor,
) -> torch.Tensor:
"""
Compute fused linear + JSD loss.
This is the main entry point that uses the memory-efficient autograd function.
"""
return FusedLinearJSDFunction.apply(
student_input,
student_weight,
teacher_input,
teacher_weight,
beta,
ignore_index,
temperature,
)
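A minimal end-to-end sketch with hypothetical sizes; calling .backward() on the returned scalar rescales the gradients cached during the forward pass:

s_in = torch.randn(64, 256, device=DEVICE, requires_grad=True)
t_in = torch.randn(64, 256, device=DEVICE)
s_w = torch.randn(512, 256, device=DEVICE, requires_grad=True)
t_w = torch.randn(512, 256, device=DEVICE)
loss = fused_linear_jsd_fwd(0.5, -100, 1.0, s_w, t_w, s_in, t_in)
loss.backward()  # populates s_in.grad and s_w.grad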
Benchmark Entry Point Function#
def fused_linear_jsd_fwd_tritonbench(
tb_op: object,
student_input: torch.Tensor,
teacher_input: torch.Tensor,
label: torch.Tensor | None = None,
) -> Callable[[], torch.Tensor]:
assert label is None
# pyrefly: ignore [missing-attribute]
baseline_op = tb_op.baseline_op
beta = baseline_op.jsd.beta
ignore_index = baseline_op.jsd.ignore_index
temperature = baseline_op.temperature
student_weight = baseline_op.student_lin.weight
teacher_weight = baseline_op.teacher_lin.weight
return lambda: fused_linear_jsd_fwd(
beta,
ignore_index,
temperature,
student_weight,
teacher_weight,
student_input,
teacher_input,
)
Reference Implementation#
def fused_linear_jsd_pytorch(
beta: float,
ignore_index: int,
temperature: float,
student_weight: torch.Tensor,
teacher_weight: torch.Tensor,
student_input: torch.Tensor,
teacher_input: torch.Tensor,
) -> torch.Tensor:
"""Reference PyTorch implementation for verification."""
student_logits = student_input @ student_weight.T
teacher_logits = teacher_input @ teacher_weight.T
# Scale by temperature
student_scaled = student_logits / temperature
teacher_scaled = teacher_logits / temperature
# Compute probabilities
student_prob = torch.softmax(student_scaled.to(torch.float), dim=-1)
teacher_prob = torch.softmax(teacher_scaled.to(torch.float), dim=-1)
# Compute log probabilities
student_log_prob = torch.log_softmax(student_scaled.to(torch.float), dim=-1)
teacher_log_prob = torch.log_softmax(teacher_scaled.to(torch.float), dim=-1)
# Mixture distribution
m = (1 - beta) * student_prob + beta * teacher_prob
log_m = torch.log(m)
# JSD = (1-beta) * KL(student || m) + beta * KL(teacher || m)
student_kl = (student_prob * (student_log_prob - log_m)).sum(dim=-1)
teacher_kl = (teacher_prob * (teacher_log_prob - log_m)).sum(dim=-1)
loss = (1 - beta) * student_kl + beta * teacher_kl
return loss.mean()
Verification Function#
def check(m: int, n: int, k: int) -> None:
student_input = torch.rand([m, n], device=DEVICE, dtype=torch.float)
teacher_input = torch.rand([m, n], device=DEVICE, dtype=torch.float)
student_weight = torch.rand([k, n], device=DEVICE, dtype=torch.float)
teacher_weight = torch.rand([k, n], device=DEVICE, dtype=torch.float)
run_example(
fused_linear_jsd_fwd,
fused_linear_jsd_pytorch,
(0.5, -100, 1.0, student_weight, teacher_weight, student_input, teacher_input),
)
Main Function#
def main() -> None:
check(1024, 4096, 128256)
if __name__ == "__main__":
main()