Overview
This tutorial demonstrates how to implement a high-performance matrix multiplication (GEMM) kernel in Wave Lang.
The kernel computes C = A @ B.T with the following specifications (a plain-PyTorch sketch of the same computation follows the list):
- A: M×K matrix in f16
- B: N×K matrix in f16
- C: M×N matrix in f32
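For reference, here is the target computation written directly in PyTorch — a minimal sketch with arbitrary sizes; the verification step later in the tutorial uses the same expression:

```python
import torch

# Reference semantics of the kernel: C = A @ B.T, f16 inputs, f32 result
A = torch.randn(128, 64, dtype=torch.float16)  # M x K
B = torch.randn(256, 64, dtype=torch.float16)  # N x K
C = torch.mm(A.float(), B.float().T)           # M x N in f32
```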
What You'll Learn
By the end of this tutorial, you'll understand how to:
- Define symbolic dimensions and constraints in Wave Lang
- Implement efficient matrix multiplication with mixed precision
- Optimize memory access patterns and hardware utilization
- Use Wave Lang's iteration constructs for reduction operations
Implementation
Imports and Symbolic Dimensions
First, we import the necessary Wave Lang modules and define our symbolic dimensions. Symbolic dimensions allow Wave Lang to generate kernels that work with different matrix sizes. We also declare symbols for the address spaces of the operands and for the number of elements each thread loads and stores; these are referenced inside the kernel and bound to concrete values at compile time.
```python
import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
import torch

# Define symbolic dimensions
M = tkl.sym.M  # Rows of A and C
N = tkl.sym.N  # Rows of B and columns of C
K = tkl.sym.K  # Columns of A and B

# Define workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K

# Address spaces for the inputs (A, B) and the output (C)
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0

# Number of elements each thread loads / stores per access
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
```
Kernel Constraints
Wave Lang separates computation logic from scheduling decisions through constraints. These constraints define how the computation is tiled, scheduled, and mapped to hardware.
```python
constraints = [
    # Workgroup constraints: each workgroup owns one BLOCK_M x BLOCK_N tile of C
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    # Tiling constraint: step through the reduction dimension in BLOCK_K chunks
    tkw.TilingConstraint(K, BLOCK_K),
    # Wave constraints: split each workgroup tile across a 2x2 grid of waves
    tkw.WaveConstraint(M, BLOCK_M / 2),
    tkw.WaveConstraint(N, BLOCK_N / 2),
    # Hardware constraints specify target architecture features
    tkw.HardwareConstraint(
        threads_per_wave=64,
        waves_per_block=(2, 2, 1),
        mma_type=tkw.MMAType.F32_16x16x16_F16,  # mixed-precision MMA instruction
    ),
]
```
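To make the tiling concrete, here is the launch geometry these constraints imply for the problem and tile sizes used later in this tutorial — a back-of-the-envelope sketch using plain Python names so the tkl symbols above are not shadowed; the compiler derives the actual grid itself:

```python
import math

m, n, k = 2048, 2048, 1024
block_m, block_n, block_k = 128, 128, 32

# One workgroup per BLOCK_M x BLOCK_N output tile
workgroups = (math.ceil(m / block_m), math.ceil(n / block_n))  # (16, 16)

# Each workgroup runs a 2x2 grid of 64-thread waves, and each wave
# owns a (BLOCK_M/2) x (BLOCK_N/2) sub-tile of the output
threads_per_workgroup = 2 * 2 * 64                             # 256
wave_tile = (block_m // 2, block_n // 2)                       # (64, 64)

# Each workgroup steps through the reduction dimension in BLOCK_K chunks
k_steps = math.ceil(k / block_k)                               # 32

print(workgroups, threads_per_workgroup, wave_tile, k_steps)
```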
GEMM Kernel Definition
The core GEMM kernel uses Wave Lang's @tkw.iterate construct to express the reduction over the K dimension: the f32 accumulator is threaded through the loop via init_args, each iteration reads one tile of A and B and feeds an MMA, and the final result is written out after the loop. The kernel logic focuses purely on the mathematical computation.
```python
@tkw.wave(constraints)
def gemm_kernel(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
    """
    High-performance GEMM kernel with mixed precision.

    Computes C = A @ B.T where:
    - A and B are f16 inputs
    - C is f32 output for numerical stability
    """
    # f32 accumulator register, initialized to zero
    c_reg = tkl.Register[M, N, tkl.f32](0.0)

    # Iterate over tiles of the reduction dimension K,
    # carrying the accumulator from one iteration to the next
    @tkw.iterate(K, init_args=[c_reg])
    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        # Load input tiles
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        # Matrix multiply-accumulate: f16 inputs, f32 accumulation
        return tkw.mma(a_reg, b_reg, acc)

    # Write the final accumulator back to global memory
    tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
```
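Logically, the @tkw.iterate loop performs a K-tiled reduction. The sketch below emulates it in plain PyTorch at the granularity of whole tiles, ignoring the per-wave and per-thread distribution that Wave handles for you:

```python
import torch

def tiled_gemm_reference(a: torch.Tensor, b: torch.Tensor, block_k: int = 32) -> torch.Tensor:
    """Emulate the K-tiled reduction: C = A @ B.T, accumulated block by block in f32."""
    m, k = a.shape
    n, _ = b.shape
    acc = torch.zeros(m, n, dtype=torch.float32)
    for k0 in range(0, k, block_k):
        a_tile = a[:, k0:k0 + block_k].float()
        b_tile = b[:, k0:k0 + block_k].float()
        acc += a_tile @ b_tile.T  # one iterate step: mma(a_tile, b_tile, acc)
    return acc

a = torch.randn(128, 64, dtype=torch.float16)
b = torch.randn(256, 64, dtype=torch.float16)
# Loose tolerance: the blocked summation order differs from torch.mm's
torch.testing.assert_close(tiled_gemm_reference(a, b), a.float() @ b.float().T,
                           rtol=1e-4, atol=1e-4)
```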
Compilation and Usage
Once the kernel is defined, we bind its symbolic dimensions and tuning parameters to concrete values, compile it, and call it on PyTorch tensors. Wave Lang generates optimized GPU code based on our constraints and the chosen hyperparameters.
```python
from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE, SHARED_ADDRESS_SPACE
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config

# Compile the kernel with specific problem and tile sizes
options = WaveCompileOptions(
    subs={
        M: 2048, N: 2048, K: 1024,
        BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32,
        LOAD_ELEMS_PER_THREAD: 4,
        STORE_ELEMS_PER_THREAD: 4,
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,    # stage A and B tiles in shared memory
        ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,  # write C directly to global memory
    },
)
options = set_default_run_config(options)
kernel = wave_compile(options, gemm_kernel)
```
```python
# Create input tensors on the GPU
A = torch.randn(2048, 1024, dtype=torch.float16, device="cuda")  # M x K
B = torch.randn(2048, 1024, dtype=torch.float16, device="cuda")  # N x K, used as B.T
C = torch.zeros(2048, 2048, dtype=torch.float32, device="cuda")  # M x N

# Execute the compiled kernel
kernel(A, B, C)

# Verify correctness against PyTorch (f16 inputs, so allow a loose tolerance)
expected = torch.mm(A.float(), B.float().T)
assert torch.allclose(C, expected, rtol=1e-2, atol=1e-2)
```
Performance Optimization Techniques
Mixed Precision Computation
This GEMM implementation uses mixed precision: f16 inputs with f32 accumulation. This approach provides the memory bandwidth benefits of f16 while maintaining numerical stability through f32 accumulation.
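To see why f32 accumulation matters, here is a small, self-contained comparison (plain PyTorch, not Wave code) of accumulating f16 products in f16 versus f32 for a single long dot product:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096)
y = torch.randn(4096)
ref = torch.dot(x.double(), y.double()).item()  # high-precision reference

products = x.half() * y.half()  # f16 inputs, f16 products

# f32 accumulation of f16 products (what the mixed-precision MMA units do)
acc_f32 = products.float().sum().item()

# Naive f16 accumulation, emulated with an explicit rounding step per add
acc_f16 = torch.zeros((), dtype=torch.float16)
for p in products:
    acc_f16 = (acc_f16 + p).to(torch.float16)

print(f"f32 accumulation error: {abs(acc_f32 - ref):.5f}")
print(f"f16 accumulation error: {abs(acc_f16.item() - ref):.5f}")
```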
Memory Access Optimization
Wave Lang automatically optimizes memory access patterns based on the constraints (rough numbers for this tutorial's configuration follow the list):
- Coalesced Access: Memory reads are aligned for optimal bandwidth
- Shared Memory: Tiles are cached in fast shared memory
- Register Blocking: Data reuse is maximized within registers
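As a rough sanity check of what these optimizations mean for the configuration used in this tutorial, assuming the compiler stages one A tile and one B tile per workgroup in shared memory (the typical strategy; the exact staging is decided by the compiler):

```python
# Per-wave vectorized read: 64 threads x 4 f16 elements x 2 bytes
bytes_per_wave_read = 64 * 4 * 2  # 512 bytes per read instruction

# Shared-memory footprint of one K-step per workgroup:
# a BLOCK_M x BLOCK_K tile of A plus a BLOCK_N x BLOCK_K tile of B, both f16
block_m, block_n, block_k = 128, 128, 32
lds_bytes = (block_m * block_k + block_n * block_k) * 2  # 16384 bytes = 16 KiB

print(bytes_per_wave_read, lds_bytes)
```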
Hardware Utilization
The HardwareConstraint selects mixed-precision Matrix-Multiply-Accumulate (MMA) instructions (here the 16x16x16 f16-in/f32-out variant), which are what allow the kernel to approach peak matrix throughput on modern GPUs.
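One way to relate this to concrete hardware numbers is to time the compiled kernel and convert the runtime to TFLOP/s. The sketch below uses standard PyTorch CUDA events and assumes the kernel, A, B, and C objects defined earlier are in scope:

```python
import torch

m, n, k = 2048, 2048, 1024
flops = 2 * m * n * k  # one multiply and one add per inner-product term

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(100):
    kernel(A, B, C)
end.record()
torch.cuda.synchronize()

ms_per_call = start.elapsed_time(end) / 100
print(f"{flops / (ms_per_call * 1e-3) / 1e12:.1f} TFLOP/s")
```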
Advanced Features
Configurable Precision
You can modify the kernel for different precision combinations by changing the type annotations, for example switching the f16 inputs to bf16, as sketched below.
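A minimal sketch of a bf16 variant follows. It assumes your Wave build and target expose tkl.bf16 and a matching MMA variant (for example tkw.MMAType.F32_16x16x16_BF16) to use in the HardwareConstraint; everything else stays the same as the f16 kernel:

```python
# Same structure as gemm_kernel; only the input element type changes.
# Remember to switch the HardwareConstraint's mma_type to a bf16 MMA variant.
@tkw.wave(constraints)
def gemm_bf16(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.bf16],
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.bf16],
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
    c_reg = tkl.Register[M, N, tkl.f32](0.0)

    @tkw.iterate(K, init_args=[c_reg])
    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        return tkw.mma(a_reg, b_reg, acc)

    tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
```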
Dynamic Shapes
Wave Lang's symbolic dimensions mean the same kernel definition can be compiled for any matrix size simply by changing the values in subs, and dimensions can also be left dynamic at compile time so that a single compiled kernel serves a range of shapes, which makes it a good fit for dynamic neural network workloads.
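As an illustration, if your Wave version exposes a dynamic_symbols option on WaveCompileOptions (the option name here is an assumption; check the installed API), the compile step could look roughly like this, with the dynamic dimensions omitted from subs and taken from the tensor shapes at call time:

```python
# Hypothetical sketch: leave M and K dynamic, keep the tile sizes fixed.
dyn_options = WaveCompileOptions(
    subs={
        N: 2048,
        BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32,
        LOAD_ELEMS_PER_THREAD: 4,
        STORE_ELEMS_PER_THREAD: 4,
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
    },
    dynamic_symbols=(M, K),  # assumption: M and K are read from the inputs at call time
)
dyn_options = set_default_run_config(dyn_options)
dyn_kernel = wave_compile(dyn_options, gemm_kernel)
```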
Conclusion
This tutorial demonstrated how Wave Lang's separation of concerns makes it easy to implement high-performance GEMM kernels. By separating the mathematical logic from scheduling decisions, you can:
- Focus on the algorithm rather than low-level GPU programming details
- Easily experiment with different optimization strategies
- Achieve portable performance across different GPU architectures
- Maintain readable and maintainable code
The same principles apply to other linear algebra operations like convolutions, attention mechanisms, and custom ML operators. Wave Lang makes GPU programming both powerful and enjoyable!