# SPDX-License-Identifier: MIT
# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# Copyright 2018-2020 Philippe Tillet
# Copyright 2020-2022 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import math
import os
import statistics

import torch

def is_simulation_env() -> bool:
"""
Return True if running in a simulation environment (e.g. pre-silicon).
When True, Iris will force the torch allocator regardless of allocator_type.
Set IRIS_SIMULATION=1 (or "true"/"yes") to enable.
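
    Example:
        >>> os.environ["IRIS_SIMULATION"] = "1"
        >>> is_simulation_env()
        True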
"""
val = os.environ.get("IRIS_SIMULATION", "").strip().lower()
return val in ("1", "true", "yes")

def get_simulation_device_id(local_rank: int) -> int:
"""
Get the device ID to use in simulation mode.
In simulation, multiple ranks may need to share the same physical device.
This function wraps the local_rank to ensure it's within available device bounds.
Args:
local_rank: The local rank from the environment
Returns:
Device ID that's guaranteed to be valid (wrapped if needed)
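
    Example:
        >>> # Illustrative only: with 4 visible devices, rank 6 wraps to device 2
        >>> get_simulation_device_id(6)  # doctest: +SKIP
        2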
"""
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        return 0  # Fallback when no devices are detected
    # Wrap onto the available devices; in simulation, multiple ranks can share a device.
    return local_rank % num_devices

def get_device_id_for_rank(local_rank: int) -> int:
"""
Get the device ID to use for a given local rank.
In simulation mode, this automatically wraps the rank to handle multiple ranks
sharing a single GPU. In normal mode, it returns the local_rank as-is.
Args:
local_rank: The local rank from the environment (typically from LOCAL_RANK env var)
Returns:
Device ID that's guaranteed to be valid
Example:
>>> import iris
>>> local_rank = int(os.environ.get("LOCAL_RANK", 0))
>>> device_id = iris.get_device_id_for_rank(local_rank)
>>> torch.cuda.set_device(device_id)
"""
if is_simulation_env():
return get_simulation_device_id(local_rank)
else:
return local_rank

def get_empty_cache_for_benchmark():
    """Allocate a 256 MiB int32 buffer used to flush the GPU cache between timing runs."""
    cache_size = 256 * 1024 * 1024  # bytes; int32 elements are 4 bytes each
    return torch.empty(cache_size // 4, dtype=torch.int, device="cuda")


def clear_cache(cache):
    """Zero the cache buffer so each timing iteration starts from a cold cache."""
    cache.zero_()


def create_timing_event():
    """Create a CUDA event with timing enabled, for use in do_bench."""
    return torch.cuda.Event(enable_timing=True)


def _quantile(a, q):
    """Linearly interpolated quantiles of ``a`` for each value in ``q``.

    Pure-Python stand-in for ``numpy.quantile`` with linear interpolation.
    """
    n = len(a)
    a = sorted(a)

    def get_quantile(p):
        if not (0 <= p <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = p * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        # Interpolate between the two nearest order statistics.
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(p) for p in q]


def _summarize_statistics(times, quantiles, return_mode):
    """Reduce raw per-iteration times to the requested quantiles or a summary statistic."""
    if quantiles is not None:
        ret = _quantile(times, quantiles)
        # Unwrap single-element results so callers get a scalar back.
        if len(ret) == 1:
            ret = ret[0]
        return ret
    if return_mode == "all":
        return times
    elif return_mode == "min":
        return min(times)
    elif return_mode == "max":
        return max(times)
    elif return_mode == "mean":
        return statistics.mean(times)
    elif return_mode == "median":
        return statistics.median(times)
    raise ValueError(f"Unknown return_mode: {return_mode!r}")
def do_bench(
fn,
barrier_fn=lambda: None,
preamble_fn=lambda: None,
n_warmup=25,
n_repeat=100,
quantiles=None,
return_mode="mean",
):
"""
Benchmark a function by timing its execution.
Args:
fn (callable): Function to benchmark.
barrier_fn (callable, optional): Function to call for synchronization. Default: no-op.
preamble_fn (callable, optional): Function to call before each execution. Default: no-op.
n_warmup (int, optional): Number of warmup iterations. Default: 25.
n_repeat (int, optional): Number of timing iterations. Default: 100.
        quantiles (list, optional): Quantiles to return instead of a summary statistic. Default: None.
return_mode (str, optional): Summary statistic to return ("mean", "min", "max", "median", "all"). Default: "mean".
Returns:
float or list: Timing result(s) in milliseconds.
    Example:
        >>> import iris
        >>> iris_ctx = iris.iris(1 << 20)
        >>> def test_fn():
        ...     tensor = iris_ctx.zeros(1000, 1000)
>>> time_ms = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier)
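        >>> # Hypothetical variant: request median and p99 latency instead of the mean
        >>> p50, p99 = iris.do_bench(test_fn, barrier_fn=iris_ctx.barrier, quantiles=[0.5, 0.99])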
"""
    # Wait for anything that happened before this call
    barrier_fn()
    preamble_fn()
    fn()
    # Wait for all GPUs to finish their work before allocating the cache buffer
    barrier_fn()
    cache = get_empty_cache_for_benchmark()
    start_event = [create_timing_event() for _ in range(n_repeat)]
    end_event = [create_timing_event() for _ in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
barrier_fn() # Wait for all GPUs before we clear the cache
preamble_fn()
clear_cache(cache)
barrier_fn() # Wait for clearing the cache before launching any kernels
fn()
# Benchmark
for i in range(n_repeat):
barrier_fn() # Wait for all GPUs before we clear the cache
preamble_fn()
clear_cache(cache)
barrier_fn() # Wait for clearing the cache before launching any kernels
start_event[i].record()
fn()
end_event[i].record()
        barrier_fn()  # Wait for all ranks before reading the clocks
    # Ensure every recorded event has completed before querying elapsed times.
    torch.cuda.synchronize()
    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
    return _summarize_statistics(times, quantiles, return_mode)
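

# Minimal usage sketch (illustrative only, not part of the module's public API):
# benchmark a plain matmul on a single GPU with the default no-op barrier.
if __name__ == "__main__":
    if torch.cuda.is_available():
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")
        time_ms = do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=50)
        print(f"matmul 1024x1024: {time_ms:.3f} ms")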