GEMM Tutorial: High-Performance Matrix Multiplication

Learn how to implement a highly optimized matrix multiplication kernel using Wave Lang. This comprehensive tutorial covers constraints, memory management, and performance optimization techniques for achieving peak GPU performance.

Overview

This tutorial demonstrates implementing a high-performance matrix multiplication (GEMM) kernel using Wave Lang. The kernel computes C = A @ B.T with the following specifications:

  • A: M×K matrix in f16
  • B: N×K matrix in f16
  • C: M×N matrix in f32
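
In plain PyTorch terms, the computation looks like the reference below (orientation only; the sizes are illustrative):

Python

import torch

M, N, K = 128, 128, 64  # example sizes
A = torch.randn(M, K, dtype=torch.float16)
B = torch.randn(N, K, dtype=torch.float16)
C = A.float() @ B.float().T  # (M, N) result, accumulated in f32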

What You'll Learn

By the end of this tutorial, you'll understand how to:

  • Define symbolic dimensions and constraints in Wave Lang
  • Implement efficient matrix multiplication with mixed precision
  • Optimize memory access patterns and hardware utilization
  • Use Wave Lang's iteration constructs for reduction operations

Implementation

Imports and Symbolic Dimensions

First, we import the necessary Wave Lang modules and define our symbolic dimensions. Symbolic dimensions allow Wave Lang to generate kernels that work with different matrix sizes.

Python

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
import torch

# Define symbolic dimensions
M = tkl.sym.M  # Rows of A and C
N = tkl.sym.N  # Rows of B and columns of C
K = tkl.sym.K  # Columns of A and B (the reduction dimension)

# Define workgroup tile sizes and other compile-time symbols
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE      # Address space for the A and B tiles
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0  # Address space for the C output
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

Kernel Constraints

Wave Lang separates computation logic from scheduling decisions through constraints. These constraints define how the computation is tiled, scheduled, and mapped to hardware.

Python
constraints = [
    # Workgroup constraints define how work is distributed across GPU blocks:
    # tile M along workgroup dimension 0 and N along workgroup dimension 1
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),

    # Tiling constraint for the K reduction dimension
    tkw.TilingConstraint(K, BLOCK_K),

    # Wave constraints control intra-workgroup parallelism:
    # each wave handles a (BLOCK_M/2) x (BLOCK_N/2) subtile
    tkw.WaveConstraint(M, BLOCK_M / 2),
    tkw.WaveConstraint(N, BLOCK_N / 2),

    # Hardware constraints specify target architecture features
    tkw.HardwareConstraint(
        threads_per_wave=64,
        waves_per_block=(2, 2, 1),
        mma_type=tkw.MMAType.F32_16x16x16_F16,  # Mixed-precision MMA
    ),
]

GEMM Kernel Definition

The core GEMM kernel uses Wave Lang's @tkw.iterate decorator to handle the reduction over the K dimension. The kernel logic focuses purely on the mathematical computation.

Python

@tkw.wave(constraints)
def gemm_kernel(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
    """
    High-performance GEMM kernel with mixed precision.

    Computes C = A @ B.T where:
    - A and B are f16 inputs
    - C is f32 output for numerical stability
    """
    # Initialize the f32 accumulator register to zero
    c_reg = tkl.Register[M, N, tkl.f32](0.0)

    # Iterate over the K dimension, carrying the accumulator between steps
    @tkw.iterate(K, init_args=[c_reg])
    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        # Load input tiles
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)

        # Matrix multiply-accumulate: f16 inputs, f32 accumulation
        return tkw.mma(a_reg, b_reg, acc)

    # Write the final accumulator back to memory
    tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

Compilation and Usage

Once the kernel is defined, we compile it by substituting concrete values for the symbolic dimensions and tile sizes, then call it on PyTorch tensors. Wave Lang generates optimized GPU code based on our constraints.

Python
# Compile the kernel, substituting concrete values for the symbols
from wave_lang.kernel.lang.global_symbols import (
    GLOBAL_ADDRESS_SPACE,
    SHARED_ADDRESS_SPACE,
)
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile

options = WaveCompileOptions(
    subs={
        M: 2048, N: 2048, K: 1024,
        BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32,
        LOAD_ELEMS_PER_THREAD: 4,
        STORE_ELEMS_PER_THREAD: 4,
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,    # stage A and B tiles in shared memory
        ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,  # write C directly to global memory
    },
    canonicalize=True,
)
kernel = wave_compile(options, gemm_kernel)

# Create input tensors
A = torch.randn(2048, 1024, dtype=torch.float16, device='cuda')
B = torch.randn(2048, 1024, dtype=torch.float16, device='cuda')  # Will be transposed
C = torch.zeros(2048, 2048, dtype=torch.float32, device='cuda')

# Execute the compiled kernel
kernel(A, B, C)

# Verify correctness against PyTorch
expected = torch.mm(A.float(), B.float().T)
assert torch.allclose(C, expected, rtol=1e-3, atol=1e-2)  # tolerances sized for f16 inputs

Performance Optimization Techniques

Mixed Precision Computation

This GEMM implementation uses mixed precision: f16 inputs with f32 accumulation. This approach provides the memory bandwidth benefits of f16 while maintaining numerical stability through f32 accumulation.
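
To see the effect, here is a small standalone PyTorch experiment (illustrative only, not part of the kernel) comparing f16 and f32 accumulation of the same dot product:

Python

import torch

torch.manual_seed(0)
a = torch.randn(2048, dtype=torch.float16)
b = torch.randn(2048, dtype=torch.float16)

exact = torch.dot(a.double(), b.double()).item()  # high-precision reference
f32_acc = torch.dot(a.float(), b.float()).item()  # f16 values, f32 accumulation

f16_acc = torch.tensor(0.0, dtype=torch.float16)
for x, y in zip(a, b):
    f16_acc = f16_acc + x * y                     # every partial sum rounded to f16

print(f"f16 accumulation error: {abs(f16_acc.item() - exact):.5f}")
print(f"f32 accumulation error: {abs(f32_acc - exact):.5f}")
# f32 accumulation is typically orders of magnitude closer to the reference.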

Memory Access Optimization

Wave Lang automatically optimizes memory access patterns based on the constraints (a rough shared-memory footprint estimate follows the list):

  • Coalesced Access: Memory reads are aligned for optimal bandwidth
  • Shared Memory: Tiles are cached in fast shared memory
  • Register Blocking: Data reuse is maximized within registers
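
To see why the tiles fit comfortably in shared memory, here is a back-of-the-envelope footprint estimate (our own arithmetic, using the tile sizes chosen in the compilation step above):

Python

# Per-workgroup shared-memory footprint of the A and B tiles (f16 = 2 bytes)
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
bytes_a = BLOCK_M * BLOCK_K * 2  # 8 KiB for the A tile
bytes_b = BLOCK_N * BLOCK_K * 2  # 8 KiB for the B tile
print(f"{(bytes_a + bytes_b) / 1024:.0f} KiB per workgroup")
# 16 KiB, well under the 64 KiB of LDS available per workgroup on typical AMD GPUs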

Hardware Utilization

The HardwareConstraint selects mixed-precision Matrix-Multiply-Accumulate (MMA) instructions, which are what make near-peak theoretical throughput attainable on modern GPUs.
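
As a concrete illustration (our own arithmetic, not a Wave Lang API), each F32_16x16x16_F16 instruction performs a full 16x16x16 multiply-accumulate, so the instruction count for this tutorial's problem size follows directly:

Python

# One 16x16x16 MMA performs 2 * 16 * 16 * 16 FLOPs (a multiply and an add per element)
flops_per_mma = 2 * 16 * 16 * 16      # 8,192 FLOPs per instruction
total_flops = 2 * 2048 * 2048 * 1024  # ~8.6 GFLOPs for the full GEMM
print(total_flops // flops_per_mma)   # ~1.05 million MMA instructions in total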

Advanced Features

Configurable Precision

You can easily modify the kernel for different precision combinations by changing the Memory element types and the MMA type in the constraints.
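
For example, a bf16 variant could look like the sketch below. This is hypothetical: constraints_bf16 stands for a copy of the earlier constraints with mma_type=tkw.MMAType.F32_16x16x16_BF16, assuming your target's MMA units support bf16.

Python

# Hypothetical bf16 variant: only the element types and the MMA type change
@tkw.wave(constraints_bf16)
def gemm_bf16(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.bf16],
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.bf16],
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
    c_reg = tkl.Register[M, N, tkl.f32](0.0)

    @tkw.iterate(K, init_args=[c_reg])
    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        return tkw.mma(
            tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD),
            tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD),
            acc,
        )

    tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)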

Dynamic Shapes

Wave Lang's symbolic dimensions let one kernel definition serve many matrix sizes. You can substitute concrete values at compile time for maximum specialization, or leave dimensions dynamic so a single compiled kernel handles varying shapes at runtime, which is ideal for dynamic neural network workloads.
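
A sketch of how this might look, assuming your Wave Lang version supports a dynamic_symbols compile option (hypothetical here; check your version's WaveCompileOptions):

Python

# Hypothetical dynamic-shape compilation: M and N stay symbolic in the binary
options = WaveCompileOptions(
    subs={
        K: 1024,
        BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32,
        LOAD_ELEMS_PER_THREAD: 4,
        STORE_ELEMS_PER_THREAD: 4,
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
    },
    dynamic_symbols=[M, N],  # bound from the actual tensor shapes at call time
)
dynamic_kernel = wave_compile(options, gemm_kernel)

# The same compiled kernel now serves multiple problem sizes
for m, n in [(1024, 1024), (2048, 4096)]:
    A = torch.randn(m, 1024, dtype=torch.float16, device='cuda')
    B = torch.randn(n, 1024, dtype=torch.float16, device='cuda')
    C = torch.zeros(m, n, dtype=torch.float32, device='cuda')
    dynamic_kernel(A, B, C)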

Conclusion

This tutorial demonstrated how Wave Lang's separation of concerns makes it easy to implement high-performance GEMM kernels. By separating the mathematical logic from scheduling decisions, you can:

  • Focus on the algorithm rather than low-level GPU programming details
  • Easily experiment with different optimization strategies
  • Achieve portable performance across different GPU architectures
  • Maintain readable and maintainable code

The same principles apply to other linear algebra operations like convolutions, attention mechanisms, and custom ML operators. Wave Lang makes GPU programming both powerful and enjoyable!