Iris Class API#
Factory Function#
Prefer using the convenience factory over calling the constructor directly:
- iris(heap_size=1073741824)[source]#
Create and return an Iris instance with the specified heap size.
- Parameters:
heap_size (int) – Size of the symmetric heap in bytes. Defaults to 1 GiB (1073741824 bytes).
- Returns:
An initialized Iris instance.
- Return type:
Iris
Example
>>> import iris
>>> iris_ctx = iris.iris(2**30)  # 1 GiB heap
>>> tensor = iris_ctx.zeros(1024, 1024)
Core Methods#
- Iris.get_heap_bases()[source]#
Return the tensor of symmetric heap base addresses for all ranks.
- Returns:
A 1D tensor of uint64 heap base addresses of size num_ranks, resident on the Iris device. Pass this to device-side Triton kernels that require heap translation (see the sketch after the example below).
- Return type:
torch.Tensor
Example
>>> ctx = iris.iris(1 << 20)
>>> heap_bases = ctx.get_heap_bases()
>>> print(heap_bases.shape)  # torch.Size([num_ranks])
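The heap-bases tensor is typically passed straight through to a Triton kernel together with the buffers it should translate. The sketch below is illustrative only: the device-side call iris.store and its argument order, as well as the ctx.num_ranks attribute, are assumptions here, so consult the device-side API reference for the exact names and signatures.

import triton
import triton.language as tl
import iris

@triton.jit
def fill_remote(buffer, n_elements, cur_rank: tl.constexpr, peer_rank: tl.constexpr,
                heap_bases, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Assumed device-side API: write 1.0 into peer_rank's copy of `buffer`,
    # translating the local address through heap_bases.
    iris.store(buffer + offsets, 1.0, cur_rank, peer_rank, heap_bases, mask=mask)

ctx = iris.iris(1 << 20)
buf = ctx.zeros(4096)                         # symmetric allocation on every rank
heap_bases = ctx.get_heap_bases()
peer = (ctx.cur_rank + 1) % ctx.num_ranks     # next rank in a ring (num_ranks assumed)
grid = (triton.cdiv(4096, 256),)
fill_remote[grid](buf, 4096, ctx.cur_rank, peer, heap_bases, BLOCK_SIZE=256)
ctx.barrier()                                 # make the remote writes visible everywhere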
- Iris.barrier(stream=None)[source]#
Synchronize all ranks and their CUDA devices.
This first calls torch.cuda.synchronize() (or stream.synchronize() when a stream is given) to ensure the local GPU has finished all queued work, then performs a global distributed barrier so that all ranks reach the same point before proceeding.
- Parameters:
stream (optional) – If given, wait only for that stream before the barrier. If None, use the legacy behavior (device-wide synchronization).
Example
>>> ctx = iris.iris(1 << 20)
>>> ctx.barrier()  # Synchronize all ranks
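When work is enqueued on a non-default CUDA stream, passing that stream restricts the local synchronization to it. A minimal sketch, assuming the queued work is an Iris allocation as in the earlier examples:

import torch

ctx = iris.iris(1 << 20)
stream = torch.cuda.Stream(device=ctx.get_device())
with torch.cuda.stream(stream):
    tensor = ctx.zeros(1024, 1024)   # work enqueued on `stream`
ctx.barrier(stream=stream)           # drain only `stream`, then barrier across all ranks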
- Iris.get_device()[source]#
Get the underlying device where the Iris symmetric heap resides.
- Returns:
The CUDA device of Iris-managed memory.
- Return type:
torch.device
Example
>>> ctx = iris.iris(1 << 20)
>>> device = ctx.get_device()
>>> print(device)  # cuda:0
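A common pattern is to place ordinary (non-symmetric) PyTorch tensors on the same device so they can be mixed freely with Iris-allocated buffers. A short sketch, assuming Iris-allocated tensors behave like regular torch tensors:

import torch

ctx = iris.iris(1 << 20)
scratch = torch.empty(1024, 1024, device=ctx.get_device())   # regular tensor, not on the symmetric heap
result = ctx.zeros(1024, 1024)                               # Iris-managed, same device
scratch.copy_(result)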
- Iris.get_cu_count()[source]#
Get the number of compute units (CUs) for the current GPU.
- Returns:
Number of compute units on this rank’s GPU.
- Return type:
int
Example
>>> ctx = iris.iris(1 << 20)
>>> cu_count = ctx.get_cu_count()
>>> print(f"GPU has {cu_count} CUs")  # GPU has 304 CUs
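One common use is sizing the launch grid of a persistent-style Triton kernel so that exactly one program runs per CU. A minimal sketch (the kernel here only counts its programs to keep the example self-contained):

import torch
import triton
import triton.language as tl

@triton.jit
def count_programs(counter_ptr):
    tl.atomic_add(counter_ptr, 1)   # each persistent program bumps the counter once

ctx = iris.iris(1 << 20)
num_cus = ctx.get_cu_count()
counter = torch.zeros(1, dtype=torch.int32, device=ctx.get_device())
count_programs[(num_cus,)](counter)   # one program per compute unit
assert counter.item() == num_cus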
Logging Helpers#
Use Iris-aware logging that automatically annotates each message with the current rank and world size. This is helpful when debugging multi-rank programs.
- set_logger_level(level)[source]#
Set the logging level for the iris logger.
- Parameters:
level – Logging level (iris.DEBUG, iris.INFO, iris.WARNING, iris.ERROR)
Example
>>> ctx = iris.iris()
>>> iris.set_logger_level(iris.DEBUG)
>>> ctx.debug("This will now be visible")  # [Iris] [0/1] This will now be visible
- Iris.debug(message)[source]#
Log a debug message with rank information.
- Parameters:
message (str) – Human-readable message to log at debug level.
Notes
The log record is enriched with iris_rank and iris_num_ranks so formatters can display the originating rank and world size.
Example
>>> ctx = iris.iris()
>>> iris.set_logger_level(iris.DEBUG)
>>> ctx.debug("Allocating buffers")  # [Iris] [0/1] Allocating buffers
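Because the records carry iris_rank and iris_num_ranks, a custom handler can surface them in its own format. A short sketch, assuming the logger is registered under the name "iris" (adjust if the library uses a different name):

import logging
import iris

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(
    "[rank %(iris_rank)s/%(iris_num_ranks)s] %(levelname)s: %(message)s"
))
logging.getLogger("iris").addHandler(handler)

ctx = iris.iris()
iris.set_logger_level(iris.DEBUG)
ctx.debug("custom formatting")   # [rank 0/1] DEBUG: custom formatting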
- Iris.info(message)[source]#
Log an info message with rank information.
- Parameters:
message (str) – Human-readable message to log at info level.
Example
>>> ctx = iris.iris()
>>> ctx.info("Starting iteration 0")  # [Iris] [0/1] Starting iteration 0
Utility Functions#
- do_bench(fn, barrier_fn=<function <lambda>>, preamble_fn=<function <lambda>>, n_warmup=25, n_repeat=100, quantiles=None, return_mode='mean')[source]#
Benchmark a function by timing its execution.
- Parameters:
fn (callable) – Function to benchmark.
barrier_fn (callable, optional) – Function to call for synchronization. Default: no-op.
preamble_fn (callable, optional) – Function to call before each execution. Default: no-op.
n_warmup (int, optional) – Number of warmup iterations. Default: 25.
n_repeat (int, optional) – Number of timing iterations. Default: 100.
quantiles (list, optional) – Quantiles to return instead of summary statistic. Default: None.
return_mode (str, optional) – Summary statistic to return (“mean”, “min”, “max”, “median”, “all”). Default: “mean”.
- Returns:
Timing result(s) in milliseconds.
- Return type:
float or list of floats
Example
>>> import iris
>>> iris_ctx = iris.iris(1 << 20)
>>> def test_fn():
...     tensor = iris_ctx.zeros(1000, 1000)
>>> time_ms = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier)
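When quantiles are requested, one value per requested quantile is returned instead of the single summary statistic. A small sketch that benchmarks a matmul while keeping all ranks in lockstep, assuming Iris-allocated tensors behave like regular torch tensors:

import iris

iris_ctx = iris.iris(1 << 20)
a = iris_ctx.zeros(1024, 1024)

def workload():
    return a @ a

p50, p95 = iris.do_bench(
    workload,
    barrier_fn=iris_ctx.barrier,   # synchronize ranks between iterations
    quantiles=[0.5, 0.95],
)
print(f"p50={p50:.3f} ms  p95={p95:.3f} ms")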
Broadcast Helper#
Broadcast a Python scalar or small object from a source rank to all ranks. This is a convenience wrapper over the internal Torch Distributed helper.
- Iris.broadcast(value, source_rank)[source]#
Broadcast a Python scalar or small picklable object from one rank to all ranks.
- Parameters:
value (Any) – The value to broadcast. Only the source_rank value is used; other ranks should pass a placeholder (e.g., None).
source_rank (int) – Rank id that holds the authoritative value.
- Returns:
The value broadcast to all ranks.
- Return type:
Any
Example
>>> ctx = iris.iris()
>>> value = 42 if ctx.cur_rank == 0 else None
>>> value = ctx.broadcast(value, source_rank=0)  # All ranks get 42
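The same pattern works for small configuration objects: rank 0 decides, every other rank receives. A short sketch:

import iris

ctx = iris.iris()
config = {"block_size": 256, "iters": 100} if ctx.cur_rank == 0 else None
config = ctx.broadcast(config, source_rank=0)   # every rank now holds the same dict
ctx.info(f"Using config {config}")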