Wave Lang Examples

These practical examples show how Wave Lang simplifies GPU programming, from basic element-wise operations to complete ML kernels.

Vector Addition

The classic "Hello World" of GPU programming: the kernel adds two tensors element-wise, while the declared constraints tell Wave how to parallelize the work and vectorize memory accesses.

import torch

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions and the address space
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Define constraints for the kernel
constraints = [
    tkw.HardwareConstraint(vector_shapes={0: 4}),
    tkw.WorkgroupConstraint(M=64, N=64, BLOCK_M=16, BLOCK_N=16)
]

@tkw.wave(constraints)
def vector_add(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
    """Add two vectors element-wise."""
    lhs = tkw.read(a)
    rhs = tkw.read(b)
    res = lhs + rhs
    tkw.write(res, c)

# Compile to an optimized GPU kernel
kernel = tkw.compile(vector_add, constraints)

# Create input tensors
a = torch.randn(1024, 512, dtype=torch.float16, device='cuda')
b = torch.randn(1024, 512, dtype=torch.float16, device='cuda')
c = torch.zeros_like(a)

# Execute the compiled kernel
kernel(a, b, c)
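
As a quick host-side sanity check, assuming the compile-and-run steps above completed, the kernel's output can be compared against PyTorch's own element-wise addition on the same inputs:

# Reference check: the GPU result should match PyTorch's element-wise add.
expected = a + b
torch.testing.assert_close(c, expected)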

Softmax Function

A numerically stable softmax built from read, reduction, and broadcast operations. It shows how Wave expresses row-wise reductions and other composite mathematical patterns.

import torch

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions and the address space
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Define constraints for the kernel
constraints = [
    tkw.HardwareConstraint(vector_shapes={0: 4}),
    tkw.WorkgroupConstraint(M=32, N=128, BLOCK_M=16, BLOCK_N=16),
    tkw.WaveConstraint(M=16, N=16)
]

@tkw.wave(constraints)
def softmax(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
    """Numerically stable softmax across the last dimension."""
    val = tkw.read(a)

    # Subtract the row maximum for numerical stability
    row_max = tkw.max(val, dim=N)
    row_max_bcast = tkw.broadcast(row_max, [M, N])
    val -= row_max_bcast

    # Exponentiate
    val = tkw.exp(val)

    # Sum and normalize
    denominator = tkw.sum(val, dim=N)
    denom_broadcast = tkw.broadcast(denominator, [M, N])
    val = val / denom_broadcast

    tkw.write(val, b)

# Compile the softmax kernel
kernel = tkw.compile(softmax, constraints)

# Batch of sequences to normalize
logits = torch.randn(32, 128, dtype=torch.float32, device='cuda')
probs = torch.zeros_like(logits)

# Apply softmax
kernel(logits, probs)

# Verify each row sums to 1
assert torch.allclose(probs.sum(dim=-1), torch.ones(32, device='cuda'))
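
The result can also be cross-checked against torch.softmax, which should agree with the kernel output up to floating-point rounding (assuming the kernel ran as above):

# Reference check against PyTorch's softmax over the last dimension.
expected = torch.softmax(logits, dim=-1)
torch.testing.assert_close(probs, expected)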

Matrix Multiplication

High-performance matrix multiplication, with tiling driven by the workgroup, wave, and tiling constraints. This example shows Wave's strength in dense linear algebra.

import torch

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions and tuning symbols
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Define constraints for the GEMM kernel
constraints = [
    tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1),
                           mfma_type=tkw.MfmaType.F32_32x32x8_F16),
    tkw.WorkgroupConstraint(M=128, N=128, BLOCK_M=2, BLOCK_N=2),
    tkw.TilingConstraint(K=32, BLOCK_K=1),
    tkw.WaveConstraint(M=64, N=64)
]

@tkw.wave(constraints)
@tkw.iterate(K, init_args=[c_reg])
def gemm_kernel(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[K, N, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
    c_reg: tkl.Register[M, N, tkl.f32],
):
    """Matrix multiplication with fp16 inputs and fp32 accumulation."""
    a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)

    # Matrix multiply-accumulate
    c_reg = tkw.mma(a_reg, b_reg, c_reg)

    # Write result
    tkw.write(c_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

# Compile the GEMM kernel
kernel = tkw.compile(gemm_kernel, constraints)

# Large matrices for demonstration
A = torch.randn(2048, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(1024, 2048, dtype=torch.float16, device='cuda')
C = torch.zeros(2048, 2048, dtype=torch.float32, device='cuda')

# Compute matrix product
kernel(A, B, C)

# Verify correctness (fp16 inputs leave some rounding error)
expected = torch.mm(A.float(), B.float())
assert torch.allclose(C, expected, rtol=1e-2, atol=1e-2)

# Advanced: Configurable GEMM with different precisions
def create_gemm(M, N, K, dtype_a, dtype_b, dtype_c):
    @tkw.wave(constraints)
    @tkw.iterate(K, init_args=[c_reg])
    def configurable_gemm(
        a: tkl.Memory[M, K, ADDRESS_SPACE, dtype_a],
        b: tkl.Memory[K, N, ADDRESS_SPACE, dtype_b],
        c: tkl.Memory[M, N, ADDRESS_SPACE_0, dtype_c],
        c_reg: tkl.Register[M, N, dtype_c],
    ):
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        c_reg = tkw.mma(a_reg, b_reg, c_reg)
        tkw.write(c_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    return configurable_gemm

# Mixed precision GEMM: FP16 inputs, FP32 accumulation
kernel_fp16 = create_gemm(M, N, K, tkl.f16, tkl.f16, tkl.f32)
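
The factory returns an uncompiled kernel, so it goes through the same compile-and-call flow as the fixed-shape GEMM above. A minimal host-side check is sketched below; the shapes are illustrative, not prescribed by this example:

# Compile the mixed-precision kernel and verify it against torch.mm.
compiled_fp16 = tkw.compile(kernel_fp16, constraints)

A16 = torch.randn(512, 256, dtype=torch.float16, device='cuda')
B16 = torch.randn(256, 512, dtype=torch.float16, device='cuda')
C32 = torch.zeros(512, 512, dtype=torch.float32, device='cuda')

compiled_fp16(A16, B16, C32)

# fp16 inputs with fp32 accumulation leave some rounding error.
expected = torch.mm(A16.float(), B16.float())
assert torch.allclose(C32, expected, rtol=1e-2, atol=1e-2)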

Attention Mechanism

A complete scaled dot-product attention implementation. It demonstrates how Wave composes matrix multiplies, reductions, and broadcasts into a single kernel.

import torch

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

# Define symbolic dimensions and tuning symbols
B = tkl.sym.B
S1 = tkl.sym.S1
S2 = tkl.sym.S2
H = tkl.sym.H
D = tkl.sym.D
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Define constraints for the attention kernel
constraints = [
    tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1),
                           mfma_type=tkw.MfmaType.F32_16x16x16_F16),
    tkw.WorkgroupConstraint(B=1, H=2, S1=64, S2=64, D=64,
                            BLOCK_B=1, BLOCK_H=1, BLOCK_S1=64, BLOCK_S2=64, BLOCK_D=64),
    tkw.TilingConstraint(S2=64, BLOCK_S2=1),
    tkw.WaveConstraint(B=1, H=1, S1=16, S2=16, D=16)
]

@tkw.wave(constraints)
@tkw.iterate(S2, init_args=[acc])
def attention_kernel(
    q: tkl.Memory[B, H, S1, D, ADDRESS_SPACE, tkl.f16],
    k: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    v: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    output: tkl.Memory[B, H, S1, D, ADDRESS_SPACE_0, tkl.f16],
    acc: tkl.Register[B, H, S1, D, tkl.f32],
):
    """Scaled dot-product attention kernel."""
    q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)

    # Compute Q @ K^T (fold the 1/sqrt(D) scale into Q on the host if needed)
    qk = tkw.mma(q_reg, tkw.permute(k_reg, target_shape=[B, H, D, S2]))

    # Apply softmax to the attention scores
    row_max = tkw.max(qk, dim=S2)
    row_max_bcast = tkw.broadcast(row_max, [B, H, S1, S2])
    qk -= row_max_bcast
    attn_weights = tkw.exp(qk)
    row_sum = tkw.sum(attn_weights, dim=S2)
    row_sum_bcast = tkw.broadcast(row_sum, [B, H, S1, S2])
    attn_weights = attn_weights / row_sum_bcast

    # Apply attention weights to the values
    acc = tkw.mma(attn_weights, v_reg, acc)

    tkw.write(acc, output, elements_per_thread=STORE_ELEMS_PER_THREAD)

# Compile the attention kernel
kernel = tkw.compile(attention_kernel, constraints)

# Transformer dimensions
batch_size, num_heads, seq_len, head_dim = 8, 12, 512, 64

# Create query, key, and value tensors
Q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
K = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
V = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
output = torch.zeros_like(Q)

# Compute attention
kernel(Q, K, V, output)
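
As a correctness check, the output can be compared against a plain PyTorch evaluation of softmax(Q @ K^T) @ V, which is what the kernel above computes. Note that the 1/sqrt(head_dim) scale is not applied inside the kernel, so fold it into Q on the host if you need standard scaled attention. A sketch, assuming the kernel ran successfully:

# Plain PyTorch reference for the unscaled attention computed by the kernel.
scores = torch.matmul(Q.float(), K.float().transpose(-2, -1))
reference = torch.matmul(torch.softmax(scores, dim=-1), V.float())
torch.testing.assert_close(output.float(), reference, rtol=1e-2, atol=1e-2)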

# Memory-efficient attention with causal masking
@tkw.wave(constraints)
@tkw.iterate(S2, init_args=[acc])
def causal_attention(
    q: tkl.Memory[B, H, S1, D, ADDRESS_SPACE, tkl.f16],
    k: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    v: tkl.Memory[B, H, S2, D, ADDRESS_SPACE, tkl.f16],
    output: tkl.Memory[B, H, S1, D, ADDRESS_SPACE_0, tkl.f16],
    acc: tkl.Register[B, H, S1, D, tkl.f32],
):
    """Causal attention with automatic masking."""
    q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)

    # Compute Q @ K^T
    qk = tkw.mma(q_reg, tkw.permute(k_reg, target_shape=[B, H, D, S2]))

    # Add the additive causal mask (MASK and MASK_SHAPE are supplied by the host)
    mask = tkw.get_custom_user_vector(MASK, MASK_SHAPE)
    qk = qk + mask

    # Softmax with numerical stability
    row_max = tkw.max(qk, dim=S2)
    row_max_bcast = tkw.broadcast(row_max, [B, H, S1, S2])
    qk = tkw.exp(qk - row_max_bcast)
    row_sum = tkw.sum(qk, dim=S2)
    row_sum_bcast = tkw.broadcast(row_sum, [B, H, S1, S2])
    attn_weights = qk / row_sum_bcast

    acc = tkw.mma(attn_weights, v_reg, acc)
    tkw.write(acc, output, elements_per_thread=STORE_ELEMS_PER_THREAD)
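
The kernel consumes an additive mask; how it is bound to MASK and MASK_SHAPE is not shown in this example. As a sketch of the usual host-side construction (seq_len here is just an illustrative size), an additive causal mask is zero on and below the diagonal and -inf above it:

# Additive causal mask: 0 on and below the diagonal, -inf above it,
# so future positions contribute nothing after the softmax.
seq_len = 512
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float('-inf'), device='cuda'),
    diagonal=1,
)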

Ready to try these examples?

Install Wave and start experimenting with high-performance GPU kernels.

pip install wave-lang
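
After installing, a quick smoke test in Python confirms the package imports cleanly before you run the examples above:

# Verify the Wave Lang modules used throughout these examples are importable.
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw

print("Wave Lang imported successfully")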