Iris Class API#

Factory Function#

Prefer using the convenience factory over calling the constructor directly:

iris(heap_size=1073741824)[source]#

Create and return an Iris instance with the specified heap size.

Parameters:

heap_size (int) – Size of the symmetric heap in bytes. Defaults to 1 GiB (1073741824 bytes).

Returns:

An initialized Iris instance.

Return type:

Iris

Example

>>> import iris
>>> iris_ctx = iris.iris(2**30)  # 1GB heap
>>> tensor = iris_ctx.zeros(1024, 1024)

Core Methods#

Iris.get_heap_bases()[source]#

Return the tensor of symmetric heap base addresses for all ranks.

Returns:

A 1D uint64 tensor of length num_ranks, resident on the Iris device, containing the heap base address of every rank. Pass it to device-side Triton kernels that require heap translation.

Return type:

torch.Tensor

Example

>>> ctx = iris.iris(1 << 20)
>>> heap_bases = ctx.get_heap_bases()
>>> print(heap_bases.shape)  # torch.Size([num_ranks])
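A minimal sketch of forwarding the heap-bases tensor into a Triton kernel. The kernel below is a hypothetical placeholder (not part of this API); it only shows that the result of get_heap_bases() is passed like any other tensor argument, to be consumed by Iris device-side helpers.

>>> import triton
>>> import triton.language as tl
>>> @triton.jit
... def fill_kernel(buf_ptr, heap_bases_ptr, n_elems, BLOCK: tl.constexpr):
...     # heap_bases_ptr arrives as an ordinary pointer argument; device-side
...     # Iris helpers would use it for remote-address translation.
...     offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
...     mask = offs < n_elems
...     tl.store(buf_ptr + offs, tl.full((BLOCK,), 1.0, tl.float32), mask=mask)
>>> ctx = iris.iris(1 << 20)
>>> buf = ctx.zeros(1024)
>>> fill_kernel[(triton.cdiv(1024, 256),)](buf, ctx.get_heap_bases(), 1024, BLOCK=256)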
Iris.barrier(stream=None)[source]#

Synchronize all ranks and their CUDA devices.

This first calls torch.cuda.synchronize() (or stream.synchronize() if a stream is given) to ensure the local GPU has finished all queued work, then performs a global distributed barrier so that all ranks reach the same point before proceeding.

Parameters:

stream (torch.cuda.Stream, optional) – If given, wait only for that stream before the barrier. If None, fall back to the legacy device-wide synchronization.

Example

>>> ctx = iris.iris(1 << 20)
>>> ctx.barrier()  # Synchronize all ranks
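When work has been enqueued on a specific stream, pass it so only that stream is awaited before the barrier (a brief sketch; the enqueued work is illustrative):

>>> import torch
>>> stream = torch.cuda.Stream()
>>> with torch.cuda.stream(stream):
...     buf = ctx.zeros(1024, 1024)  # work enqueued on this stream
>>> ctx.barrier(stream=stream)  # wait only for this stream, then sync all ranks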
Iris.get_device()[source]#

Get the underlying device where the Iris symmetric heap resides.

Returns:

The CUDA device of Iris-managed memory.

Return type:

torch.device

Example

>>> ctx = iris.iris(1 << 20)
>>> device = ctx.get_device()
>>> print(device)  # cuda:0
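The returned device can also be used to place ordinary (non-heap) tensors on the same GPU as Iris-managed memory; a brief sketch:

>>> import torch
>>> scratch = torch.empty(1024, device=ctx.get_device())  # regular torch allocation on the same GPU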
Iris.get_cu_count()[source]#

Get the number of compute units (CUs) for the current GPU.

Returns:

Number of compute units on this rank’s GPU.

Return type:

int

Example

>>> ctx = iris.iris(1 << 20)
>>> cu_count = ctx.get_cu_count()
>>> print(f"GPU has {cu_count} CUs")  # GPU has 304 CUs
Iris.get_rank()[source]#

Get this process’s rank id in the distributed communicator.

Returns:

Zero-based rank id of the current process.

Return type:

int

Example

>>> ctx = iris.iris(1 << 20)
>>> rank = ctx.get_rank()
>>> print(f"This is rank {rank}")  # This is rank 0
Iris.get_num_ranks()[source]#

Get the total number of ranks in the distributed communicator.

Returns:

World size (number of ranks).

Return type:

int

Example

>>> ctx = iris.iris(1 << 20)
>>> num_ranks = ctx.get_num_ranks()
>>> print(f"Total ranks: {num_ranks}")  # Total ranks: 1

Logging Helpers#

Use Iris-aware logging that automatically annotates each message with the current rank and world size. This is helpful when debugging multi-rank programs.

set_logger_level(level)[source]#

Set the logging level for the iris logger.

Parameters:

level – Logging level (iris.DEBUG, iris.INFO, iris.WARNING, iris.ERROR)

Example

>>> ctx = iris.iris()
>>> iris.set_logger_level(iris.DEBUG)
>>> ctx.debug("This will now be visible")  # [Iris] [0/1] This will now be visible
Iris.debug(message)[source]#

Log a debug message with rank information.

Parameters:

message (str) – Human-readable message to log at debug level.

Notes

The log record is enriched with iris_rank and iris_num_ranks so formatters can display the originating rank and world size.

Example

>>> ctx = iris.iris()
>>> iris.set_logger_level(iris.DEBUG)
>>> ctx.debug("Allocating buffers")  # [Iris] [0/1] Allocating buffers
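Because each record carries iris_rank and iris_num_ranks, a standard logging.Formatter can reproduce the rank annotation. A sketch, assuming the logger is registered under the name "iris" (an assumption; adjust to your setup):

>>> import logging
>>> handler = logging.StreamHandler()
>>> handler.setFormatter(logging.Formatter("[Iris] [%(iris_rank)s/%(iris_num_ranks)s] %(message)s"))
>>> logging.getLogger("iris").addHandler(handler)  # logger name is an assumption
>>> ctx.debug("Allocating buffers")  # [Iris] [0/1] Allocating buffers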
Iris.info(message)[source]#

Log an info message with rank information.

Parameters:

message (str) – Human-readable message to log at info level.

Example

>>> ctx = iris.iris()
>>> ctx.info("Starting iteration 0")  # [Iris] [0/1] Starting iteration 0
Iris.warning(message)[source]#

Log a warning message with rank information.

Parameters:

message (str) – Human-readable message to log at warning level.

Example

>>> ctx = iris.iris()
>>> ctx.warning("Memory usage is high")  # [Iris] [0/1] Memory usage is high
Iris.error(message)[source]#

Log an error message with rank information.

Parameters:

message (str) – Human-readable message to log at error level.

Example

>>> ctx = iris.iris()
>>> ctx.error("Failed to allocate memory")  # [Iris] [0/1] Failed to allocate memory

Utility Functions#

do_bench(fn, barrier_fn=<function <lambda>>, preamble_fn=<function <lambda>>, n_warmup=25, n_repeat=100, quantiles=None, return_mode='mean')[source]#

Benchmark a function by timing its execution.

Parameters:
  • fn (callable) – Function to benchmark.

  • barrier_fn (callable, optional) – Function to call for synchronization. Default: no-op.

  • preamble_fn (callable, optional) – Function to call before each execution. Default: no-op.

  • n_warmup (int, optional) – Number of warmup iterations. Default: 25.

  • n_repeat (int, optional) – Number of timing iterations. Default: 100.

  • quantiles (list, optional) – Quantiles to return instead of a single summary statistic. Default: None.

  • return_mode (str, optional) – Summary statistic to return (“mean”, “min”, “max”, “median”, “all”). Default: “mean”.

Returns:

Timing result(s) in milliseconds.

Return type:

float or list

Example

>>> import iris
>>> iris_ctx = iris.iris(1 << 20)
>>> def test_fn():
...     tensor = iris_ctx.zeros(1000, 1000)
>>> time_ms = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier)
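To obtain a latency distribution instead of a single summary statistic, pass quantiles (a sketch reusing test_fn above; the unpacking assumes one result per requested quantile, as in Triton's do_bench):

>>> p20, p50, p80 = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier, quantiles=[0.2, 0.5, 0.8])
>>> print(f"median: {p50:.3f} ms")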

Broadcast Helper#

Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper.

Iris.broadcast(value, source_rank)[source]#

Broadcast a Python scalar or small picklable object from one rank to all ranks.

Parameters:
  • value (Any) – The value to broadcast. Only the value provided on source_rank is used; other ranks should pass a placeholder (e.g., None).

  • source_rank (int) – Rank id that holds the authoritative value.

Returns:

The value broadcast to all ranks.

Return type:

Any

Example

>>> ctx = iris.iris()
>>> value = 42 if ctx.get_rank() == 0 else None
>>> value = ctx.broadcast(value, source_rank=0)  # All ranks get 42
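The same call also works for small picklable objects such as a configuration dict (the dict contents below are illustrative):

>>> cfg = {"tile_size": 256, "iters": 10} if ctx.get_rank() == 0 else None
>>> cfg = ctx.broadcast(cfg, source_rank=0)  # every rank now holds the same dict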