Programming Model
Iris is an open-source, Triton-based framework for Remote Memory Access (RMA)[1] operations, written in only a few hundred lines of code. It provides SHMEM-like APIs within Triton for multi-GPU programming.


Designed by Experts, Built for Scale
- Written from scratch by GPU and distributed computing experts
- Minimal dependencies: only Triton, PyTorch, HIP runtime and mpi4py (for initialization)
- No external frameworks or heavyweight runtimes beyond core stack

Clean Abstractions
- Full Symmetric Heap implementation in Python
- Pythonic PyTorch-like host APIs for tensor allocation and construction
- Pythonic Triton-style device APIs for load, store, and atomic ops

Communication + Computation
- Device-side collective ops: broadcast, scatter, reduce, etc.
- Lock variants for communication and computation overlap
- Fine-grained GEMM + communication overlap via workgroup specialization

Scalable by Design
- Full scale-up (multi-GPU node) support
- Scale-out (multi-node) in progress

Simple load & store APIs
@triton.jit
def load(pointer, to_rank, from_rank, heap_bases, mask=None):
    """
    Loads a value from the specified rank's memory location.

    This function performs a memory read operation by translating the pointer
    from the to_rank's address space to the from_rank's address space and loading
    data from the target memory location. If the from_rank and to_rank are the same,
    this function performs a local load operation.

    Args:
        pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the to_rank's address space that will be translated to the from_rank's address space. The pointer must be local to the current rank.
        to_rank (int): The rank ID that receives the data. Must be the current rank, where the pointer is local.
        from_rank (int): The rank ID from which to read the data.
        heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
        mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.

    Returns:
        Block: The loaded value from the target memory location.
    """

@triton.jit
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
    """
    Writes data to the specified rank's memory location.

    This function performs a memory write operation by translating the pointer
    from the from_rank's address space to the to_rank's address space and storing
    the provided data to the target memory location. If the from_rank and to_rank are the same,
    this function performs a local store operation.

    Args:
        pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space. The pointer must be local to the current rank.
        value (Block): The tensor of elements to be stored.
        from_rank (int): The rank ID from which the pointer originates. Must be the current rank, where the pointer is local.
        to_rank (int): The rank ID to which the data will be written.
        heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
        mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.

    Returns:
        None
    """
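To make the roles of from_rank and to_rank concrete, here is a minimal pure-Python sketch of the semantics described in the docstrings above. It is not Iris code: the integer addresses, the `heaps` dictionary, and the helpers `_translate` and `_deref` are all hypothetical. It assumes, as the argument descriptions state, that the pointer is always local to the calling rank, which is to_rank for load and from_rank for store.

```python
# Pure-Python model of the load/store semantics (hypothetical, not Iris code).

HEAP_SIZE = 16  # elements per rank in this toy model

# Each rank owns one heap; heap_bases[r] is the address where rank r's heap starts.
heap_bases = [1000, 5000]
heaps = {base: [0] * HEAP_SIZE for base in heap_bases}


def _translate(pointer, from_rank, to_rank):
    """Rebase an address from from_rank's heap onto to_rank's heap."""
    return heap_bases[to_rank] + (pointer - heap_bases[from_rank])


def _deref(address):
    """Find the heap that contains an address and the index inside it."""
    for base, heap in heaps.items():
        if base <= address < base + HEAP_SIZE:
            return heap, address - base
    raise ValueError(f"address {address} is outside every heap")


def load(pointers, to_rank, from_rank, mask=None):
    """Read from from_rank's heap using pointers that are local to to_rank."""
    out = []
    for i, p in enumerate(pointers):
        if mask is not None and not mask[i]:
            out.append(None)  # masked-off lanes are not loaded
            continue
        heap, idx = _deref(_translate(p, to_rank, from_rank))
        out.append(heap[idx])
    return out


def store(pointers, values, from_rank, to_rank, mask=None):
    """Write into to_rank's heap using pointers that are local to from_rank."""
    for i, (p, v) in enumerate(zip(pointers, values)):
        if mask is not None and not mask[i]:
            continue  # masked-off lanes are not stored
        heap, idx = _deref(_translate(p, from_rank, to_rank))
        heap[idx] = v


# Rank 0 writes three values into rank 1's heap, skipping the middle lane.
local_ptrs = [heap_bases[0] + i for i in range(3)]
store(local_ptrs, [7, 8, 9], from_rank=0, to_rank=1, mask=[True, False, True])

# Rank 0 reads them back from rank 1 through the same local pointers.
remote_values = load(local_ptrs, to_rank=0, from_rank=1)
# With both ranks equal, the same call is a plain local load.
local_values = load(local_ptrs, to_rank=0, from_rank=0)
```

In this model rank 0 never holds a remote address. It passes pointers into its own heap, and the rank arguments decide which heap is actually touched.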
Symmetric Heap & Implementation


The symmetric heap is a Partitioned Global Address Space (PGAS) abstraction. The key idea is that the remote address of any symmetric variable can be computed from two offsets:
- The offset of the target process's heap base in its virtual address space
- The offset of the variable within the symmetric heap

For this to work, the allocation routine for symmetric variables must be collective, so that every process places a variable at the same heap offset, or the offset must otherwise be known. The heap base addresses must also be shared across all processes with an all_gather.
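The two requirements above (collective allocation and an all_gather of heap bases) can be illustrated with a small pure-Python model. This is only a sketch under stated assumptions: the `SymmetricHeap` bump allocator, the `all_gather` stand-in, the `remote_address` helper, and the base addresses are hypothetical and do not reflect how Iris allocates memory.

```python
# Toy model of a symmetric heap (every name here is hypothetical).

class SymmetricHeap:
    """A bump allocator over one rank's heap region."""

    def __init__(self, base, size):
        self.base = base      # where this rank's heap starts in its virtual address space
        self.size = size      # heap capacity in bytes
        self.next_offset = 0  # first free byte, relative to base

    def allocate(self, nbytes, alignment=16):
        """Reserve nbytes and return the allocation's offset within the heap."""
        offset = -(-self.next_offset // alignment) * alignment  # round up to alignment
        if offset + nbytes > self.size:
            raise MemoryError("symmetric heap exhausted")
        self.next_offset = offset + nbytes
        return offset


def all_gather(values):
    """Stand-in for a collective all_gather: every rank receives every value."""
    return [list(values) for _ in values]


# Each rank's heap lands at a different virtual address.
heaps = [SymmetricHeap(base, size=1 << 20)
         for base in (0x7F00_0000, 0x5A00_0000, 0x6C00_0000)]

# Collective allocation: all ranks make the same calls in the same order,
# so each variable ends up at the same offset on every rank.
offsets_a = [heap.allocate(100) for heap in heaps]
offsets_b = [heap.allocate(4096) for heap in heaps]

# Every rank learns every heap base (take rank 0's copy of the gathered list).
heap_bases = all_gather([heap.base for heap in heaps])[0]


def remote_address(target_rank, offset):
    """Heap base of the target process plus the variable's offset in the heap."""
    return heap_bases[target_rank] + offset


# Address of variable b on rank 2, computed without asking rank 2 anything.
addr_b_on_rank2 = remote_address(2, offsets_b[0])
```

Because the offsets agree on every rank, the only per-rank state that has to be exchanged is one base address per process.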
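import triton
import triton.language as tl

@triton.jit
def load(pointer, to_rank, from_rank, heap_bases, mask=None):
    # pointer is local to to_rank (the caller); rebase it onto from_rank's heap.
    translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
    result = tl.load(translated_ptr, mask=mask)
    return result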
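@triton.jit
def __translate(ptr, from_rank, to_rank, heap_bases):
    from_base = tl.load(heap_bases + from_rank)
    to_base = tl.load(heap_bases + to_rank)
    # Convert the pointer to an integer so the difference can be computed.
    ptr_int = tl.cast(ptr, tl.uint64)
    # Byte offset of ptr from the start of from_rank's heap.
    offset = ptr_int - from_base
    # Cast to a byte pointer so the offset is added in bytes.
    to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8))
    # Apply the same offset inside to_rank's heap.
    translated_ptr_byte = to_base_byte + offset
    # Cast back to the original pointer type.
    translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype)
    return translated_ptr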
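One detail in the translation routine that is easy to miss is the cast to an int8 pointer before the offset is added. Pointer arithmetic on a typed pointer advances in elements, not bytes, while the offset computed from the integer addresses is in bytes. The sketch below models this in plain Python. The `TypedPointer` class and the addresses are hypothetical stand-ins for illustration, not Triton or Iris types.

```python
# Hypothetical model of typed pointer arithmetic (not Triton code).
from dataclasses import dataclass


@dataclass(frozen=True)
class TypedPointer:
    address: int    # byte address
    elem_size: int  # size in bytes of the pointed-to element type

    def __add__(self, n):
        # Adding n to a typed pointer advances by n elements, i.e. n * elem_size bytes.
        return TypedPointer(self.address + n * self.elem_size, self.elem_size)

    def cast(self, elem_size):
        # A cast changes the element type, not the address.
        return TypedPointer(self.address, elem_size)


UINT64_MASK = (1 << 64) - 1


def translate(ptr, from_rank, to_rank, heap_bases):
    from_base = heap_bases[from_rank]
    to_base = heap_bases[to_rank]
    # Byte distance of ptr from the start of from_rank's heap (uint64 arithmetic).
    offset = (ptr.address - from_base) & UINT64_MASK
    # View the destination base as a byte pointer so the offset is applied in bytes.
    to_base_byte = TypedPointer(to_base, elem_size=1)
    translated_byte = to_base_byte + offset
    # Restore the original element type.
    return translated_byte.cast(ptr.elem_size)


heap_bases = [0x1000, 0x9000]

# A float32 pointer to element 3 of an array at the start of rank 0's heap.
p = TypedPointer(heap_bases[0], elem_size=4) + 3
q = translate(p, from_rank=0, to_rank=1, heap_bases=heap_bases)

# Skipping the byte cast would scale the 12-byte offset by the element size.
wrong = TypedPointer(heap_bases[1], elem_size=4) + (p.address - heap_bases[0])
```

Here `q` points 12 bytes into rank 1's heap, matching the position of `p` in rank 0's heap, while `wrong` lands 48 bytes in. Translating `q` back from rank 1 to rank 0 returns `p`.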