Examples

This page shows the example scripts included with tritonBLAS.

Basic Matrix Multiplication

The example_matmul.py script demonstrates basic matrix multiplication using the simple API:

examples/example_matmul.py
import torch
import triton
import tritonblas
import argparse


def example_matmul(m, n, k):
    """Run tritonblas.matmul on random float16 CUDA matrices and print the output.

    Args:
        m: Number of rows of the left operand and of the output tensor.
        n: Number of columns of the right operand (after transpose) and of the output.
        k: Inner dimension shared by both operands.
    """
    tensor_opts = {"device": "cuda", "dtype": torch.float16}

    lhs = torch.randn(m, k, **tensor_opts)
    # Allocated as (n, k) and then viewed through .T, so rhs has shape (k, n).
    rhs = torch.randn(n, k, **tensor_opts).T
    out = torch.zeros((m, n), **tensor_opts)

    # out is handed over as the third argument; presumably the product is
    # written into it in place (confirm against the tritonblas API docs).
    tritonblas.matmul(lhs, rhs, out)

    print(out)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Example TritonBLAS matrix multiplication with CLI parameters for m, n, k."
    )
    # One integer option per problem dimension; each defaults to 8192.
    dimension_options = (
        ("--m", "Number of rows in matrix A and C (default: 8192)"),
        ("--n", "Number of columns in matrix B (after transpose) and C (default: 8192)"),
        ("--k", "Number of columns in matrix A and rows in matrix B (default: 8192)"),
    )
    for flag, help_text in dimension_options:
        cli.add_argument(flag, type=int, default=8192, help=help_text)

    opts = cli.parse_args()
    example_matmul(opts.m, opts.n, opts.k)

Usage:

cd examples
python3 example_matmul.py
python3 example_matmul.py --m 4096 --n 4096 --k 4096

Matrix Multiplication with Selector

The example_matmul_lt.py script demonstrates using the optimized API with a pre-computed selector:

examples/example_matmul_lt.py
import torch
import triton
import tritonblas
import argparse


def example_matmul(m, n, k):
    """Run tritonblas.matmul_lt with an OrigamiMatmulSelector and print the output.

    Args:
        m: Number of rows of the left operand and of the output tensor.
        n: Number of columns of the right operand (after transpose) and of the output.
        k: Inner dimension shared by both operands.
    """
    dev = "cuda"
    half = torch.float16

    a_mat = torch.randn(m, k, device=dev, dtype=half)
    # Allocated as (n, k) and then viewed through .T, so b_mat has shape (k, n).
    b_mat = torch.randn(n, k, device=dev, dtype=half).T
    c_mat = torch.zeros((m, n), device=dev, dtype=half)

    # The selector is built from the problem sizes plus the dtypes and device
    # reported by the tensors themselves, then handed to matmul_lt.
    selector = tritonblas.OrigamiMatmulSelector(
        m, n, k, a_mat.dtype, b_mat.dtype, c_mat.dtype, a_mat.device
    )
    tritonblas.matmul_lt(a_mat, b_mat, c_mat, selector)

    print(c_mat)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Example TritonBLAS matrix multiplication with CLI parameters for m, n, k."
    )
    # Maps each dimension name to its help text; every option is an int defaulting to 8192.
    dimension_help = {
        "m": "Number of rows in matrix A and C (default: 8192)",
        "n": "Number of columns in matrix B (after transpose) and C (default: 8192)",
        "k": "Number of columns in matrix A and rows in matrix B (default: 8192)",
    }
    for dim, description in dimension_help.items():
        arg_parser.add_argument("--" + dim, type=int, default=8192, help=description)

    parsed = arg_parser.parse_args()
    example_matmul(parsed.m, parsed.n, parsed.k)

Usage:

cd examples
python3 example_matmul_lt.py
python3 example_matmul_lt.py --m 4096 --n 4096 --k 4096