Quick Start Guide
Your First Matrix Multiplication
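Here’s a simple example to get you started. It multiplies two 4096 × 4096 FP16 matrices on the GPU and writes the product into the preallocated output matrix C: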
```python
import torch
import tritonblas

# Create input matrices (FP16, on the GPU)
A = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
C = torch.zeros(4096, 4096, dtype=torch.float16, device='cuda')

# Perform the matrix multiplication C = A @ B; the result is written into C
tritonblas.matmul(A, B, C)
print(f"Result shape: {C.shape}")
```
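Printing `C.shape` only confirms the output size. When timing a multiplication like this one, a common way to judge performance is to convert elapsed time into throughput: multiplying an m × k matrix by a k × n matrix takes 2·m·n·k floating-point operations under the usual GEMM convention (one multiply and one add per inner-product term). The plain-Python sketch below does that arithmetic; the function names `gemm_flops` and `gemm_tflops` are hypothetical, and measuring `seconds` is left to you (for example with `torch.cuda.Event` timing after a warm-up call).

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    """Floating-point operations for C = A @ B with A (m x k) and B (k x n).

    Each of the m*n outputs is a length-k dot product, counted as 2*k
    operations (k multiplies plus k adds), the standard GEMM convention.
    """
    return 2 * m * n * k


def gemm_tflops(m: int, n: int, k: int, seconds: float) -> float:
    """Achieved throughput in TFLOP/s for one GEMM that took `seconds`."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return gemm_flops(m, n, k) / seconds / 1e12


if __name__ == "__main__":
    # The Quick Start example: a 4096 x 4096 by 4096 x 4096 product.
    print(f"{gemm_flops(4096, 4096, 4096) / 1e9:.1f} GFLOP per call")  # 137.4 GFLOP per call
    # Hypothetical timing: if one call took 1.0 ms
    print(f"{gemm_tflops(4096, 4096, 4096, 1e-3):.1f} TFLOP/s")  # 137.4 TFLOP/s
```

Comparing the achieved figure against your GPU's published FP16 peak gives a rough sense of how much headroom remains.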
Using the Peak Performance API
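For the best performance, use the two-step API, which separates configuration from execution: first create a MatmulHeuristicResult selector for the problem's dimensions and dtypes, then pass it to matmul_lt together with the matrices: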
```python
import torch
import tritonblas

# Step 1: Choose a configuration for these matrix dimensions and dtypes
m, n, k = 4096, 4096, 4096
selector = tritonblas.MatmulHeuristicResult(
    m, n, k,
    a_dtype=torch.float16,
    b_dtype=torch.float16,
    c_dtype=torch.float16
)

# Step 2: Create input matrices that match the selector (A: m x k, B: k x n, C: m x n)
A = torch.randn(m, k, dtype=torch.float16, device='cuda')
B = torch.randn(k, n, dtype=torch.float16, device='cuda')
C = torch.zeros(m, n, dtype=torch.float16, device='cuda')

# Step 3: Perform the matrix multiplication with the selected configuration
tritonblas.matmul_lt(A, B, C, selector)
print(f"Result shape: {C.shape}")
```
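Because configuration is a separate step, it seems natural to build the selector once per problem shape and reuse it across repeated multiplications with the same dimensions and dtypes; that reuse pattern is an assumption about intended usage, not something this guide states. The sketch below caches one selector per (m, n, k, dtypes) key and derives m, n, k from the operand shapes, checking that A is m × k, B is k × n, and C is m × n. The names `gemm_dims` and `SelectorCache` are hypothetical. The selector constructor is passed in as a parameter (for example `tritonblas.MatmulHeuristicResult`, called with the same arguments as in the example above), so the caching logic itself needs no GPU.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


def gemm_dims(a_shape, b_shape, c_shape) -> Tuple[int, int, int]:
    """Return (m, n, k) for C = A @ B, assuming 2-D operands with matching shapes."""
    m, k = a_shape
    k_b, n = b_shape
    if k != k_b:
        raise ValueError(f"inner dimensions differ: A has k={k}, B has k={k_b}")
    if tuple(c_shape) != (m, n):
        raise ValueError(f"C must be {m}x{n}, got {tuple(c_shape)}")
    return m, n, k


class SelectorCache:
    """Hypothetical helper: builds one selector per (m, n, k, dtypes) and reuses it.

    `make_selector` is assumed to accept (m, n, k, a_dtype=..., b_dtype=...,
    c_dtype=...), mirroring the MatmulHeuristicResult call in the guide.
    """

    def __init__(self, make_selector: Callable[..., Any]):
        self._make_selector = make_selector
        self._cache: Dict[Tuple[Hashable, ...], Any] = {}

    def get(self, A, B, C) -> Any:
        """Return the selector for these operands, creating it on first use."""
        m, n, k = gemm_dims(A.shape, B.shape, C.shape)
        key = (m, n, k, A.dtype, B.dtype, C.dtype)
        if key not in self._cache:
            self._cache[key] = self._make_selector(
                m, n, k, a_dtype=A.dtype, b_dtype=B.dtype, c_dtype=C.dtype
            )
        return self._cache[key]


# Usage (requires a GPU and tritonblas):
#   cache = SelectorCache(tritonblas.MatmulHeuristicResult)
#   for _ in range(10):
#       tritonblas.matmul_lt(A, B, C, cache.get(A, B, C))
```

Keying on the dtypes as well as the dimensions keeps an FP16 selector from being reused for, say, an FP32 problem of the same size.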