Iris Class API#

Factory Function#

Prefer using the convenience factory over calling the constructor directly:

iris(heap_size=1073741824)[source]#

Create and return an Iris instance with the specified heap size.

Parameters:

heap_size (int) – Size of the symmetric heap in bytes. Defaults to 1 GiB (1073741824 bytes).

Returns:

An initialized Iris instance.

Return type:

Iris
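
Example

A minimal sketch, assuming the package is imported as iris; the heap size shown is illustrative:

>>> import iris
>>> iris_ctx = iris.iris(heap_size=2**30)  # 1 GiB symmetric heap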

Core Methods#

Iris.get_heap_bases()[source]#

Return the tensor of symmetric heap base addresses for all ranks.

Returns:

A 1D uint64 tensor of length num_ranks holding each rank's symmetric heap base address, resident on the Iris device. Pass it to device-side Triton kernels that perform heap address translation.

Return type:

torch.Tensor
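
Example

A sketch of the usual pattern; my_kernel, buf, and grid_size are hypothetical stand-ins for a user-written Triton kernel and its launch configuration:

>>> heap_bases = iris_ctx.get_heap_bases()
>>> assert heap_bases.numel() == iris_ctx.get_num_ranks()
>>> my_kernel[(grid_size,)](buf, heap_bases)  # hypothetical kernel launch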

Iris.barrier(stream=None)[source]#

Synchronize all ranks and their CUDA devices.

This first calls torch.cuda.synchronize() (or stream.synchronize() when a stream is given) to ensure the local GPU has finished all queued work, then performs a global MPI barrier so that all ranks reach the same point before proceeding.

Parameters:

stream (torch.cuda.Stream, optional) – If given, wait only for that stream to complete before entering the MPI barrier. If None, perform a device-wide synchronization (the legacy behavior).
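
Example

A sketch; buf is assumed to be a tensor on the Iris device:

>>> import torch
>>> iris_ctx.barrier()  # device-wide sync, then MPI barrier
>>> s = torch.cuda.Stream()
>>> with torch.cuda.stream(s):
...     buf.add_(1)
>>> iris_ctx.barrier(stream=s)  # wait only on s, then MPI barrier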

Iris.get_device()[source]#

Get the underlying device where the Iris symmetric heap resides.

Returns:

The CUDA device of Iris-managed memory.

Return type:

torch.device
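
Example

A sketch; the tensor below is an ordinary allocation placed on the same GPU, not on the symmetric heap:

>>> import torch
>>> x = torch.zeros(1024, device=iris_ctx.get_device())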

Iris.get_cu_count()[source]#

Get the number of compute units (CUs) for the current GPU.

Returns:

Number of compute units on this rank’s GPU.

Return type:

int
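
Example

A sketch of a common use, sizing a persistent-kernel launch grid to one workgroup per CU (the kernel itself is not shown):

>>> num_cus = iris_ctx.get_cu_count()
>>> grid = (num_cus,)  # one persistent workgroup per CU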

Iris.get_rank()[source]#

Get this process’s rank id in the MPI communicator.

Returns:

Zero-based rank id of the current process.

Return type:

int

Iris.get_num_ranks()[source]#

Get the total number of ranks in the MPI communicator.

Returns:

World size (number of ranks).

Return type:

int
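
Example

A sketch combining both queries with the logging helpers described below:

>>> rank = iris_ctx.get_rank()
>>> world = iris_ctx.get_num_ranks()
>>> iris_ctx.info(f"hello from rank {rank} of {world}")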

Logging Helpers#

Use Iris-aware logging that automatically annotates each message with the current rank and world size. This is helpful when debugging multi-rank programs.

Iris.debug(message)[source]#

Log a debug message with rank information.

Parameters:

message (str) – Human-readable message to log at debug level.

Notes

The log record is enriched with iris_rank and iris_num_ranks so formatters can display the originating rank and world size.
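
These fields can be consumed by a standard logging.Formatter. A minimal sketch, assuming Iris records reach the root logger (the handler wiring here is illustrative):

>>> import logging
>>> handler = logging.StreamHandler()
>>> handler.setFormatter(logging.Formatter("[%(iris_rank)s/%(iris_num_ranks)s] %(message)s"))
>>> logging.getLogger().addHandler(handler)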

Example

>>> iris_ctx.debug("Allocating buffers")

Iris.info(message)[source]#

Log an info message with rank information.

Parameters:

message (str) – Human-readable message to log at info level.

Example

>>> iris_ctx.info("Starting iteration 0")

Iris.warning(message)[source]#

Log a warning message with rank information.

Parameters:

message (str) – Human-readable message to log at warning level.
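
Example

A sketch in the same style as the debug and info examples above:

>>> iris_ctx.warning("Falling back to device-wide sync")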

Iris.error(message)[source]#

Log an error message with rank information.

Parameters:

message (str) – Human-readable message to log at error level.
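
Example

A sketch in the same style as the other logging helpers:

>>> iris_ctx.error("Allocation failed on this rank")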

Broadcast Helper#

Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal MPI helper.

Iris.broadcast(value, source_rank)[source]#

Broadcast a Python scalar or small picklable object from one rank to all ranks.

Parameters:
  • value (Any) – The value to broadcast. Only the value supplied on source_rank is used; other ranks may pass a placeholder (e.g., None).

  • source_rank (int) – Rank id that holds the authoritative value.

Returns:

The value broadcast to all ranks.

Return type:

Any

Example

>>> value = 42 if iris_ctx.get_rank() == 0 else None
>>> value = iris_ctx.broadcast(value, source_rank=0)