Triton Device Functions

Device-side functions provided by Iris for remote memory operations and atomics.

Memory transfer operations

load

load(pointer, to_rank, from_rank, heap_bases, mask=None)

Loads a value from the specified rank’s memory location.

This function performs a memory read operation by translating the pointer from the to_rank’s address space (the current rank, where the pointer is local) to the from_rank’s address space and loading data from that location in the from_rank’s memory. If the from_rank and to_rank are the same, this function performs a local load operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the to_rank’s (current rank’s) address space that will be translated to the from_rank’s address space.

  • to_rank (int) – The ID of the current rank, in whose address space the pointer is local.

  • from_rank (int) – The rank ID from which to read the data.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.

Returns:

The loaded value from the target memory location.

Return type:

Block
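
The translation step above can be illustrated with a small pure-Python model. This is a sketch under stated assumptions, not the Iris implementation. Each rank's heap is a Python list, and addresses are plain integers with one element per address, whereas the device code works on real byte addresses. `translate`, `model_load`, and the base addresses are hypothetical. The key idea is that a pointer keeps its offset from its own heap base, and only the base changes.

```python
# Minimal pure-Python model of load(); NOT the Iris API.
# Assumptions: each rank's heap is a list, addresses are ints, one element per address.

def translate(ptr, src_rank, dst_rank, heap_bases):
    """Map an address in src_rank's heap to the same offset in dst_rank's heap."""
    offset = ptr - heap_bases[src_rank]
    return heap_bases[dst_rank] + offset


def model_load(ptrs, to_rank, from_rank, heap_bases, heaps, mask=None):
    """ptrs are local to to_rank (the current rank); data is read from from_rank."""
    result = []
    for i, ptr in enumerate(ptrs):
        if mask is not None and not mask[i]:
            result.append(None)  # masked lane: no load, value is unspecified
            continue
        remote = translate(ptr, to_rank, from_rank, heap_bases)
        result.append(heaps[from_rank][remote - heap_bases[from_rank]])
    return result


heap_bases = [0x1000, 0x8000]                 # hypothetical bases for ranks 0 and 1
heaps = [[0] * 8, [10, 11, 12, 13, 14, 15, 16, 17]]
ptrs = [heap_bases[0] + i for i in range(4)]  # pointers local to rank 0

values = model_load(ptrs, to_rank=0, from_rank=1, heap_bases=heap_bases,
                    heaps=heaps, mask=[True, True, True, False])
print(values)  # [10, 11, 12, None]
```

With to_rank equal to from_rank, `translate` is the identity, which is why the call then reduces to a local load.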

store

store(pointer, value, from_rank, to_rank, heap_bases, mask=None)

Writes data to the specified rank’s memory location.

This function performs a memory write operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and storing the provided data to the target memory location. If the from_rank and to_rank are the same, this function performs a local store operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • value (Block) – The tensor of elements to be stored.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the data will be written.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.

Returns:

None
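
For store, the model translates in the opposite direction: the pointer is local to from_rank, and the write lands on to_rank. This is again a hypothetical sketch. `model_store` is not an Iris function, and it relies on the same list-backed heaps and integer addresses.

```python
# Minimal pure-Python model of store(); NOT the Iris API.

def translate(ptr, src_rank, dst_rank, heap_bases):
    """Keep the offset from the source heap base, swap in the destination base."""
    return heap_bases[dst_rank] + (ptr - heap_bases[src_rank])


def model_store(ptrs, values, from_rank, to_rank, heap_bases, heaps, mask=None):
    """ptrs are local to from_rank (the current rank); data is written on to_rank."""
    for i, (ptr, value) in enumerate(zip(ptrs, values)):
        if mask is not None and not mask[i]:
            continue  # masked lanes write nothing
        remote = translate(ptr, from_rank, to_rank, heap_bases)
        heaps[to_rank][remote - heap_bases[to_rank]] = value


heap_bases = [0x1000, 0x8000]                 # hypothetical bases for ranks 0 and 1
heaps = [[0] * 4, [0] * 4]
ptrs = [heap_bases[0] + i for i in range(4)]  # pointers local to rank 0

model_store(ptrs, [1, 2, 3, 4], from_rank=0, to_rank=1,
            heap_bases=heap_bases, heaps=heaps, mask=[True, False, True, True])
print(heaps[1])  # [1, 0, 3, 4]

# from_rank == to_rank degenerates to a local store
model_store(ptrs, [9, 9, 9, 9], from_rank=0, to_rank=0, heap_bases=heap_bases, heaps=heaps)
print(heaps[0])  # [9, 9, 9, 9]
```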

get

get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None)

Copies data from the specified rank’s memory to the current rank’s local memory.

This function performs a memory read operation by translating the from_ptr from the current rank’s address space to the from_rank’s address space, loading data from the from_rank memory location, and storing it to the local to_ptr. If the from_rank is the same as the current rank, this function performs a local copy operation.

Parameters:
  • from_ptr (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the current rank’s address space that will be translated to the from_rank’s address space.

  • to_ptr (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the current rank’s local memory where the data will be stored.

  • from_rank (int) – The rank ID from which to read the data.

  • to_rank (int) – The current rank ID where the data will be stored.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

Returns:

None
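
get combines a translated load with a plain local store. In the sketch below, rank 1 is the current rank, so only the source addresses are translated. `model_get` is hypothetical, and the model again uses list-backed heaps and integer addresses rather than Iris itself.

```python
# Minimal pure-Python model of get(); NOT the Iris API.

def translate(ptr, src_rank, dst_rank, heap_bases):
    """Keep the offset from the source heap base, swap in the destination base."""
    return heap_bases[dst_rank] + (ptr - heap_bases[src_rank])


def model_get(from_ptrs, to_ptrs, from_rank, to_rank, heap_bases, heaps, mask=None):
    """to_rank is the current rank: read remotely from from_rank, write locally."""
    for i, (src, dst) in enumerate(zip(from_ptrs, to_ptrs)):
        if mask is not None and not mask[i]:
            continue  # masked lanes are neither loaded nor stored
        remote = translate(src, to_rank, from_rank, heap_bases)  # current rank -> from_rank
        value = heaps[from_rank][remote - heap_bases[from_rank]]
        heaps[to_rank][dst - heap_bases[to_rank]] = value       # plain local store


heap_bases = [0, 100]               # hypothetical bases; rank 1 is the current rank
heaps = [[5, 6, 7, 8, 0, 0, 0, 0], [0] * 8]
src = [100 + i for i in range(4)]   # offsets 0..3, expressed in rank 1's address space
dst = [104 + i for i in range(4)]   # local destination at offsets 4..7

model_get(src, dst, from_rank=0, to_rank=1, heap_bases=heap_bases, heaps=heaps)
print(heaps[1])  # [0, 0, 0, 0, 5, 6, 7, 8]
```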

put

put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None)

Copies data from the current rank’s local memory to the specified rank’s memory.

This function performs a memory write operation by loading data from the current rank’s from_ptr, translating the to_ptr from the current rank’s address space to the to_rank’s address space, and storing the data to the to_rank memory location. If the to_rank is the same as the current rank, this function performs a local copy operation.

Parameters:
  • from_ptr (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the current rank’s local memory from which to read data.

  • to_ptr (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the current rank’s address space that will be translated to the to_rank’s address space.

  • from_rank (int) – The current rank ID from which to read the data.

  • to_rank (int) – The rank ID to which the data will be written.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

Returns:

None
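
put is the mirror image of get: the local load is untranslated, and the destination address is translated. The sketch uses a hypothetical `model_put` under the same list-backed, integer-address assumptions.

```python
# Minimal pure-Python model of put(); NOT the Iris API.

def translate(ptr, src_rank, dst_rank, heap_bases):
    """Keep the offset from the source heap base, swap in the destination base."""
    return heap_bases[dst_rank] + (ptr - heap_bases[src_rank])


def model_put(from_ptrs, to_ptrs, from_rank, to_rank, heap_bases, heaps, mask=None):
    """from_rank is the current rank: read locally, write remotely on to_rank."""
    for i, (src, dst) in enumerate(zip(from_ptrs, to_ptrs)):
        if mask is not None and not mask[i]:
            continue  # masked lanes are neither loaded nor stored
        value = heaps[from_rank][src - heap_bases[from_rank]]    # plain local load
        remote = translate(dst, from_rank, to_rank, heap_bases)  # current rank -> to_rank
        heaps[to_rank][remote - heap_bases[to_rank]] = value


heap_bases = [0, 100]               # hypothetical bases; rank 0 is the current rank
heaps = [[1, 2, 3, 0, 0, 0], [0] * 6]
model_put([0, 1, 2], [3, 4, 5], from_rank=0, to_rank=1,
          heap_bases=heap_bases, heaps=heaps, mask=[True, True, False])
print(heaps[1])  # [0, 0, 0, 1, 2, 0]
```

Comparing this model with the get model makes the asymmetry explicit: get translates the source pointer, while put translates the destination pointer.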

Atomic operations

atomic_add

atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic add at the specified rank’s memory location.

This function performs an atomic addition operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically adding the provided data to the to_rank memory location. If the from_rank and to_rank are the same, this function performs a local atomic addition operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
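
atomic_add returns the value held before the update, so concurrent callers each observe a distinct prior value. That property is what makes it usable for counters and ticketing. The sketch below models only the read-modify-write semantics, using Python threads with a `threading.Lock` standing in for hardware atomicity. Address translation, masks, and the sem/scope arguments are left out, and `ModelCounter` is a hypothetical name.

```python
import threading


class ModelCounter:
    """One memory word on the target rank; a lock stands in for hardware atomicity."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()

    def atomic_add(self, val):
        with self._lock:
            old = self.value
            self.value = old + val
            return old  # like atomic_add: the value *before* the addition


counter = ModelCounter()
seen = []
seen_lock = threading.Lock()


def worker(n_increments):
    for _ in range(n_increments):
        old = counter.atomic_add(1)
        with seen_lock:
            seen.append(old)


threads = [threading.Thread(target=worker, args=(1000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(counter.value)  # 4000
```

Every returned old value is unique, so each caller can use it as its own ticket or slot index.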

atomic_sub

atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Atomically subtracts data from the specified rank’s memory location.

This function performs an atomic subtraction operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically subtracting the provided data from the to_rank memory location. If the from_rank and to_rank are the same, this function performs a local atomic subtraction operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block) – The tensor of elements to be subtracted atomically.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. Defaults to “acq_rel”.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The value at the memory location before the atomic subtraction.

Return type:

Block
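
The mask and return-value rules are easiest to see lane by lane. This sketch is a sequential Python model, not device code, and `model_atomic_sub` is hypothetical. Each lane whose mask is true subtracts its value and reports what it saw. Lanes whose mask is false leave memory untouched, and this model does not assume any particular return value for them.

```python
def model_atomic_sub(memory, offsets, vals, mask=None):
    """Lane-wise model of atomic_sub on the target rank's memory (a list).

    Returns the per-lane value seen before the subtraction. Masked lanes do no
    work; their return value is modelled as None (treat it as unspecified).
    """
    old = []
    for i, (off, v) in enumerate(zip(offsets, vals)):
        if mask is not None and not mask[i]:
            old.append(None)
            continue
        old.append(memory[off])
        memory[off] -= v
    return old


memory = [10, 20, 30, 40]
old = model_atomic_sub(memory, [0, 1, 2, 3], [1, 2, 3, 4], mask=[True, False, True, True])
print(old)     # [10, None, 30, 40]
print(memory)  # [9, 20, 27, 36]
```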

atomic_cas

atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None)

Atomically compares and exchanges the value at the specified rank’s memory location.

This function performs an atomic compare-and-swap operation by translating the pointer from the from_rank’s address space to the to_rank’s address space, atomically comparing the current value with the expected value cmp, and writing the new value val only if they match. If the from_rank and to_rank are the same, this function performs a local atomic compare-and-swap operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – Pointer in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • cmp (Block) – The expected value to be compared with the current value at the memory location.

  • val (Block) – The new value to be written if the compare succeeds.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. Defaults to “acq_rel”.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The value at the memory location before the operation. The swap took place if and only if this value equals cmp.

Return type:

Block
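
A common use of compare-and-swap is a spinlock. The sketch models the lock word with Python threads, where a `threading.Lock` stands in for hardware atomicity, and `ModelWord` is hypothetical. On the device, one would typically pair the acquiring CAS with sem="acquire" and the releasing exchange with sem="release", so that writes made in the critical section are visible to the next holder. That pairing is standard memory-model practice rather than something this page prescribes.

```python
import threading
import time


class ModelWord:
    """One memory word on the target rank; a lock stands in for hardware atomicity."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()

    def atomic_cas(self, cmp, val):
        with self._lock:
            old = self.value
            if old == cmp:
                self.value = val
            return old  # the swap happened iff old == cmp

    def atomic_xchg(self, val):
        with self._lock:
            old = self.value
            self.value = val
            return old


lock_word = ModelWord(0)  # 0 = free, 1 = held
shared = {"count": 0}


def worker(iterations):
    for _ in range(iterations):
        while lock_word.atomic_cas(0, 1) != 0:  # acquire: retry until 0 -> 1 succeeds
            time.sleep(0)                        # yield to other Python threads
        shared["count"] += 1                     # critical section
        lock_word.atomic_xchg(0)                 # release


threads = [threading.Thread(target=worker, args=(200,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(shared["count"])  # 800
```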

atomic_xchg

atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic exchange at the specified rank’s memory location.

This function performs an atomic exchange operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically exchanging the current value with the provided new value. If the from_rank and to_rank are the same, this function performs a local atomic exchange operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
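
Because the exchange returns the old value, several consumers racing to reset a flag can tell which one actually took it. In this thread-based model, exactly one consumer observes 1. The model is not device code, and `ModelFlag` is a hypothetical name.

```python
import threading


class ModelFlag:
    """One memory word on the target rank; a lock stands in for hardware atomicity."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()

    def atomic_xchg(self, val):
        with self._lock:
            old = self.value
            self.value = val
            return old  # the value that was replaced


flag = ModelFlag(1)  # a producer has already published the flag
results = []
results_lock = threading.Lock()


def consumer():
    got = flag.atomic_xchg(0)  # take the flag and reset it in one step
    with results_lock:
        results.append(got)


threads = [threading.Thread(target=consumer) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results.count(1))  # 1
```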

atomic_xor

atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic XOR at the specified rank’s memory location.

This function performs an atomic XOR operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically XORing the provided data into the to_rank memory location. If the from_rank and to_rank are the same, this function performs a local atomic XOR operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
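
XOR toggles bits, so applying the same value twice restores the original word. The returned old value tells each caller which state it flipped from. The sketch below is a sequential, lane-wise Python model, and `model_atomic_xor` is hypothetical. Translation, sem, and scope are omitted.

```python
def model_atomic_xor(memory, offsets, vals, mask=None):
    """Lane-wise model of atomic_xor; returns the values seen before the update."""
    old = []
    for i, (off, v) in enumerate(zip(offsets, vals)):
        if mask is not None and not mask[i]:
            old.append(None)  # masked lane: no update, return value unspecified
            continue
        old.append(memory[off])
        memory[off] ^= v
    return old


memory = [0b1010]
first = model_atomic_xor(memory, [0], [0b0110])
print(bin(memory[0]))  # 0b1100
# XOR with the same value again restores the original word
second = model_atomic_xor(memory, [0], [0b0110])
print(bin(memory[0]))  # 0b1010
```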

atomic_or

atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic OR at the specified rank’s memory location.

This function performs an atomic OR operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically ORing the provided data into the to_rank memory location. If the from_rank and to_rank are the same, this function performs a local atomic OR operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
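
atomic_or can only set bits, which makes it a natural fit for arrival masks. Each rank sets its own bit, and the caller whose returned old value already had every other bit set knows it arrived last. The thread-based model below is hypothetical. `ModelBits`, `WORLD_SIZE`, and the arrival pattern are illustrations, not part of Iris.

```python
import threading


class ModelBits:
    """One memory word on the target rank; a lock stands in for hardware atomicity."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()

    def atomic_or(self, val):
        with self._lock:
            old = self.value
            self.value = old | val
            return old  # the bits that were set before this call


WORLD_SIZE = 4  # hypothetical number of ranks
ALL_ARRIVED = (1 << WORLD_SIZE) - 1
arrivals = ModelBits(0)
last = []
last_lock = threading.Lock()


def arrive(rank):
    my_bit = 1 << rank
    old = arrivals.atomic_or(my_bit)
    if (old | my_bit) == ALL_ARRIVED:  # every other bit was already set
        with last_lock:
            last.append(rank)


threads = [threading.Thread(target=arrive, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(bin(arrivals.value), len(last))  # 0b1111 1
```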

atomic_and

atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic AND at the specified rank’s memory location.

This function performs an atomic AND operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically ANDing the provided data into the to_rank memory location. If the from_rank and to_rank are the same, this function performs a local atomic AND operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
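
atomic_and with an inverted mask clears bits, and the returned old value shows whether a bit was set beforehand. The sketch is a sequential, lane-wise model, and `model_atomic_and` is hypothetical. Python integers stand in for fixed-width device integers, which behave the same way for this bit pattern.

```python
def model_atomic_and(memory, offsets, vals, mask=None):
    """Lane-wise model of atomic_and; returns the values seen before the update."""
    old = []
    for i, (off, v) in enumerate(zip(offsets, vals)):
        if mask is not None and not mask[i]:
            old.append(None)  # masked lane: no update, return value unspecified
            continue
        old.append(memory[off])
        memory[off] &= v
    return old


memory = [0b1111, 0b1111]
bit = 0b0100
old = model_atomic_and(memory, [0, 1], [~bit, ~bit], mask=[True, False])
print([bin(x) for x in memory])  # ['0b1011', '0b1111']
print(bool(old[0] & bit))        # True
```

The final line checks the returned old value: it still had the bit set, so this caller is the one that cleared it.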

atomic_min

atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic min at the specified rank’s memory location.

This function performs an atomic min operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically replacing the value at the to_rank memory location with the minimum of that value and the provided data. If the from_rank and to_rank are the same, this function performs a local atomic min operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
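
atomic_min is typically used to fold partial results from many programs or ranks into one location, starting from the identity element for min. The thread-based sketch below uses a hypothetical `ModelMin`, and its initial value assumes an int32 word.

```python
import threading


class ModelMin:
    """One memory word on the target rank; a lock stands in for hardware atomicity."""

    def __init__(self, value):
        self.value = value
        self._lock = threading.Lock()

    def atomic_min(self, val):
        with self._lock:
            old = self.value
            self.value = min(old, val)
            return old  # the value before this call


INT32_MAX = 2**31 - 1               # identity element for min over int32
global_min = ModelMin(INT32_MAX)
candidates = [42, 7, 19, 3, 88, 3, 56]  # e.g. one partial result per rank or program


def contribute(v):
    global_min.atomic_min(v)


threads = [threading.Thread(target=contribute, args=(v,)) for v in candidates]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(global_min.value)  # 3
```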

atomic_max

atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None)

Performs an atomic max at the specified rank’s memory location.

This function performs an atomic max operation by translating the pointer from the from_rank’s address space to the to_rank’s address space and atomically replacing the value at the to_rank memory location with the maximum of that value and the provided data. If the from_rank and to_rank are the same, this function performs a local atomic max operation.

Parameters:
  • pointer (triton.PointerType, or block of dtype=triton.PointerType) – The memory locations in the from_rank’s (current rank’s) address space that will be translated to the to_rank’s address space.

  • val (Block of dtype=pointer.dtype.element_ty) – The values with which to perform the atomic operation.

  • from_rank (int) – The rank ID from which the pointer originates. Must be the current rank where the pointer is local.

  • to_rank (int) – The rank ID to which the atomic operation will be performed.

  • heap_bases (triton.PointerType) – Array containing the heap base addresses for all ranks.

  • mask (Block of triton.int1, optional) – If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.

  • sem (str, optional) – Specifies the memory semantics for the operation. Acceptable values are “acquire”, “release”, “acq_rel” (stands for “ACQUIRE_RELEASE”), and “relaxed”. If not provided, the function defaults to using “acq_rel” semantics.

  • scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, i.e., thread block), and “sys” (stands for “SYSTEM”).

Returns:

The data stored at pointer before the atomic operation.

Return type:

Block
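
When several lanes target the same address, the final value of atomic_max does not depend on the order in which they are applied. The old value each lane gets back does depend on that order. The sketch is a sequential, lane-wise model, and `model_atomic_max` is hypothetical.

```python
def model_atomic_max(memory, offsets, vals, mask=None):
    """Lane-wise model of atomic_max; lanes are applied one after another.

    The final memory contents are order-independent, but which lane sees
    which old value depends on the (unspecified) order on real hardware.
    """
    old = []
    for i, (off, v) in enumerate(zip(offsets, vals)):
        if mask is not None and not mask[i]:
            old.append(None)  # masked lane: no update, return value unspecified
            continue
        old.append(memory[off])
        memory[off] = max(memory[off], v)
    return old


# Per-bucket running maximum; two lanes target bucket 0.
memory = [0, 0, 0]
old = model_atomic_max(memory, [0, 0, 1, 2], [5, 9, 4, 100], mask=[True, True, True, False])
print(memory)  # [9, 4, 0]
```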