Wave Lang Examples
See how Wave Lang makes GPU programming simple with these practical examples, from basic element-wise operations to complete ML kernels.
Vector Addition
The classic "Hello World" of GPU programming. Add two vectors element-wise with automatic parallelization and memory optimization.
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions
M = tkl.sym.M
N = tkl.sym.N
# Symbolic address space for the kernel's memory operands
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# `constraints` is defined in the usage snippet below
@tkw.wave(constraints)
def vector_add(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
    """Add two vectors element-wise."""
    lhs = tkw.read(a)
    rhs = tkw.read(b)
    res = lhs + rhs
    tkw.write(res, c)
import torch
# Define constraints for the kernel
constraints = [
    tkw.HardwareConstraint(vector_shapes={0: 4}),
    tkw.WorkgroupConstraint(M=64, N=64, BLOCK_M=16, BLOCK_N=16),
]
# Compile to optimized GPU kernel
kernel = tkw.compile(vector_add, constraints)
# Create input tensors
a = torch.randn(1024, 512, dtype=torch.float16, device='cuda')
b = torch.randn(1024, 512, dtype=torch.float16, device='cuda')
c = torch.zeros_like(a)
# Execute the compiled kernel
kernel(a, b, c)
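As a quick sanity check (a sketch assuming the kernel writes the element-wise sum into c), the result can be compared against PyTorch's native addition:
# Compare against PyTorch's element-wise addition on the same device
expected = a + b
assert torch.allclose(c, expected, atol=1e-3)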
Softmax Function
A numerically stable softmax implementation with reduction operations. Demonstrates how Wave handles complex mathematical patterns.
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions
M = tkl.sym.M
N = tkl.sym.N
# Symbolic address space for the kernel's memory operands
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# `constraints` is defined in the usage snippet below
@tkw.wave(constraints)
def softmax(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
    """Numerically stable softmax across the last dimension."""
    val = tkw.read(a)
    # Subtract the row maximum for numerical stability
    row_max = tkw.max(val, dim=N)
    row_max_bcast = tkw.broadcast(row_max, [M, N])
    val -= row_max_bcast
    # Exponentiate
    val = tkw.exp(val)
    # Sum and normalize
    denominator = tkw.sum(val, dim=N)
    denom_broadcast = tkw.broadcast(denominator, [M, N])
    val = val / denom_broadcast
    tkw.write(val, b)
import torch
# Define constraints for the kernel
constraints = [
    tkw.HardwareConstraint(vector_shapes={0: 4}),
    tkw.WorkgroupConstraint(M=32, N=128, BLOCK_M=16, BLOCK_N=16),
    tkw.WaveConstraint(M=16, N=16),
]
# Compile the softmax kernel
kernel = tkw.compile(softmax, constraints)
# Batch of sequences to normalize
logits = torch.randn(32, 128, dtype=torch.float32, device='cuda')
probs = torch.zeros_like(logits)
# Apply softmax
kernel(logits, probs)
# Verify probabilities sum to 1
assert torch.allclose(probs.sum(dim=-1), torch.ones(32, device='cuda'))
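For a stronger check, the output can also be compared against PyTorch's reference softmax (assuming the kernel normalizes over the last dimension, as written above):
# Compare against torch.softmax as a reference implementation
expected = torch.softmax(logits, dim=-1)
assert torch.allclose(probs, expected, atol=1e-5)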
Matrix Multiplication
High-performance matrix multiplication with automatic tiling and shared memory optimization. Shows Wave's strength in linear algebra.
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
# Symbolic address spaces and per-thread load/store widths
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# `constraints` and the initial accumulator are supplied by the usage snippet below
@tkw.wave(constraints)
@tkw.iterate(K, init_args=[c_reg])
def gemm_kernel(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[K, N, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
    c_reg: tkl.Register[M, N, tkl.f32],
):
    """High-performance matrix multiplication kernel (FP16 inputs, FP32 accumulation)."""
    a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    # Matrix multiply-accumulate
    c_reg = tkw.mma(a_reg, b_reg, c_reg)
    # Write result
    tkw.write(c_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
import torch
# Define constraints for the GEMM kernel
constraints = [
    tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1),
                           mfma_type=tkw.MfmaType.F32_32x32x8_F16),
    tkw.WorkgroupConstraint(M=128, N=128, BLOCK_M=2, BLOCK_N=2),
    tkw.TilingConstraint(K=32, BLOCK_K=1),
    tkw.WaveConstraint(M=64, N=64),
]
# Compile the GEMM kernel
kernel = tkw.compile(gemm_kernel, constraints)
# Large matrices for demonstration
A = torch.randn(2048, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(1024, 2048, dtype=torch.float16, device='cuda')
C = torch.zeros(2048, 2048, dtype=torch.float32, device='cuda')
# Compute matrix product
kernel(A, B, C)
# Verify correctness
expected = torch.mm(A.float(), B.float())
assert torch.allclose(C, expected, atol=1e-3)
# Advanced: Configurable GEMM with different precisions
def create_gemm(M, N, K, dtype_a, dtype_b, dtype_c):
    @tkw.wave(constraints)
    @tkw.iterate(K, init_args=[c_reg])
    def configurable_gemm(
        a: tkl.Memory[M, K, ADDRESS_SPACE, dtype_a],
        b: tkl.Memory[K, N, ADDRESS_SPACE, dtype_b],
        c: tkl.Memory[M, N, ADDRESS_SPACE_0, dtype_c],
        c_reg: tkl.Register[M, N, dtype_c],
    ):
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        c_reg = tkw.mma(a_reg, b_reg, c_reg)
        tkw.write(c_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
    return configurable_gemm
# Mixed precision GEMM: FP16 inputs, FP32 accumulation
kernel_fp16 = create_gemm(M, N, K, tkl.f16, tkl.f16, tkl.f32)
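One possible way to use the factory, sketched under the same assumptions as the fixed-precision example above (the constraint list and tkw.compile call simply mirror the earlier snippet; A16, B16, and C32 are illustrative names):
# Compile and run the mixed-precision kernel (FP16 inputs, FP32 output)
mixed_kernel = tkw.compile(kernel_fp16, constraints)
A16 = torch.randn(2048, 1024, dtype=torch.float16, device='cuda')
B16 = torch.randn(1024, 2048, dtype=torch.float16, device='cuda')
C32 = torch.zeros(2048, 2048, dtype=torch.float32, device='cuda')
mixed_kernel(A16, B16, C32)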
Attention Mechanism
Complete scaled dot-product attention implementation. Demonstrates Wave's ability to express complex ML operations elegantly.
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions
B = tkl.sym.B
S1 = tkl.sym.S1
S2 = tkl.sym.S2
H = tkl.sym.H
D = tkl.sym.D
# Symbolic address spaces and per-thread load/store widths
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# `constraints` and the initial accumulator are supplied by the usage snippet below
@tkw.wave(constraints)
@tkw.iterate(S2, init_args=[acc])
def attention_kernel(
    q: tkl.Memory[B, H, S1, D, ADDRESS_SPACE, tkl.f16],
    k: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    v: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    output: tkl.Memory[B, H, S1, D, ADDRESS_SPACE_0, tkl.f16],
    acc: tkl.Register[B, H, S1, D, tkl.f32],
):
    """Scaled dot-product attention kernel.

    The 1/sqrt(D) scale (and the log2(e) factor required because exp2 is
    used below) is assumed to be folded into q by the caller.
    """
    q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    # Compute Q @ K^T
    qk = tkw.mma(q_reg, tkw.permute(k_reg, target_shape=[B, H, D, S2]))
    # Apply softmax to the attention scores (row max subtracted for stability)
    row_max = tkw.max(qk, dim=S2)
    row_max_bcast = tkw.broadcast(row_max, [B, H, S1, S2])
    qk -= row_max_bcast
    attn_weights = tkw.exp2(qk)
    row_sum = tkw.sum(attn_weights, dim=S2)
    row_sum_bcast = tkw.broadcast(row_sum, [B, H, S1, S2])
    attn_weights = attn_weights / row_sum_bcast
    # Apply attention weights to values
    acc = tkw.mma(attn_weights, v_reg, acc)
    tkw.write(acc, output, elements_per_thread=STORE_ELEMS_PER_THREAD)
import torch
# Define constraints for the attention kernel
constraints = [
    tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1),
                           mfma_type=tkw.MfmaType.F32_16x16x16_F16),
    tkw.WorkgroupConstraint(B=1, H=2, S1=64, S2=64, D=64,
                            BLOCK_B=1, BLOCK_H=1, BLOCK_S1=64, BLOCK_S2=64, BLOCK_D=64),
    tkw.TilingConstraint(S2=64, BLOCK_S2=1),
    tkw.WaveConstraint(B=1, H=1, S1=16, S2=16, D=16),
]
# Compile the attention kernel
kernel = tkw.compile(attention_kernel, constraints)
# Transformer dimensions
batch_size, num_heads, seq_len, head_dim = 8, 12, 512, 64
# Create query, key, value tensors
Q = torch.randn(batch_size, num_heads, seq_len, head_dim,
dtype=torch.float16, device='cuda')
K = torch.randn(batch_size, num_heads, seq_len, head_dim,
dtype=torch.float16, device='cuda')
V = torch.randn(batch_size, num_heads, seq_len, head_dim,
dtype=torch.float16, device='cuda')
output = torch.zeros_like(Q)
# Compute attention
kernel(Q, K, V, output)
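For reference, the computation the kernel performs can be written in a few lines of plain PyTorch (a sketch using an exp-based softmax; in the kernel above, the 1/sqrt(head_dim) scale and the log2(e) factor needed by exp2 are assumed to be folded into Q before launch):
# Pure-PyTorch reference for scaled dot-product attention
scale = 1.0 / (head_dim ** 0.5)
scores = torch.matmul(Q.float(), K.transpose(-2, -1).float()) * scale
weights = torch.softmax(scores, dim=-1)
reference = torch.matmul(weights, V.float()).to(Q.dtype)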
# Memory-efficient attention with causal masking
@tkw.wave(constraints)
@tkw.iterate(S2, init_args=[acc])
def causal_attention(
    q: tkl.Memory[B, H, S1, D, ADDRESS_SPACE, tkl.f16],
    k: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    v: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    output: tkl.Memory[B, H, S1, D, ADDRESS_SPACE_0, tkl.f16],
    acc: tkl.Register[B, H, S1, D, tkl.f32],
):
    """Causal attention with automatic masking."""
    q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    # Compute Q @ K^T
    qk = tkw.mma(q_reg, tkw.permute(k_reg, target_shape=[B, H, D, S2]))
    # Apply the additive causal mask
    # (MASK and MASK_SHAPE are placeholders for a user-provided mask vector)
    mask = tkw.get_custom_user_vector(MASK, MASK_SHAPE)
    qk = qk + mask
    # Softmax with numerical stability
    row_max = tkw.max(qk, dim=S2)
    row_max_bcast = tkw.broadcast(row_max, [B, H, S1, S2])
    qk = tkw.exp2(qk - row_max_bcast)
    row_sum = tkw.sum(qk, dim=S2)
    row_sum_bcast = tkw.broadcast(row_sum, [B, H, S1, S2])
    attn_weights = qk / row_sum_bcast
    acc = tkw.mma(attn_weights, v_reg, acc)
    tkw.write(acc, output, elements_per_thread=STORE_ELEMS_PER_THREAD)
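The additive mask the causal kernel consumes could be built on the host along these lines (a sketch; MASK and MASK_SHAPE above remain placeholders for however the user-provided vector is registered):
# Additive causal mask: 0 on and below the diagonal, -inf above it
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float('-inf'), device='cuda'),
    diagonal=1,
)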
Ready to try these examples?
Install Wave and start experimenting with high-performance GPU kernels.
pip install wave-lang