Fine-grained GEMM & Communication Overlap

Algorithms

For GEMM + communication kernels, we currently assume \(C = A \times B\), where (see the shape sketch after this list):

  • \(B\) (weights) is sharded column- or row-wise across GPUs,

  • \(A\) (activations) is replicated across GPUs, and

  • \(C\) (output activations) is replicated across GPUs.
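The following is a minimal, shape-only sketch of this layout in Python; the dimensions M, K, N, the world size, and the variable names are hypothetical and chosen purely for illustration.

```python
# Shape-only sketch of the assumed layout (illustrative; all sizes are hypothetical).
import torch

M, K, N = 4096, 8192, 8192    # C = A @ B with A: (M, K), B: (K, N), C: (M, N)
world_size = 8                # number of GPUs (ranks)

# Replicated on every rank:
A = torch.randn(M, K)                            # activations, full copy per rank

# Weight shards, one per rank:
B_row_shard = torch.randn(K // world_size, N)    # row-wise shard (GEMM + all-reduce)
B_col_shard = torch.randn(K, N // world_size)    # column-wise shard (GEMM + all-scatter)

# After the communication step, the full (M, N) output C is replicated on every rank.
```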

Currently, there are two implementations:

  1. GEMM + All-reduce: \(B\) is partitioned row-wise, so \(A\) must be partitioned column-wise to match. Each rank multiplies its tall, skinny slice of \(A\) by the corresponding slice of \(B\), producing a partial \(C\) of shape \(M \times N\); the all-reduce kernel then sums these partial results across all GPUs (ranks). See the figure and the sketch below.

(Figure: GEMM + all-reduce)
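As a point of reference, the same math can be written unfused with plain PyTorch collectives. The sketch below assumes torch.distributed is already initialized (e.g. via torchrun) with one rank per GPU, and the function name gemm_all_reduce is hypothetical; the actual kernels overlap the GEMM with the communication rather than running them back to back.

```python
# Unfused reference for the GEMM + all-reduce variant (sketch, not the fused kernel).
import torch
import torch.distributed as dist

def gemm_all_reduce(A_col_shard: torch.Tensor, B_row_shard: torch.Tensor) -> torch.Tensor:
    """A_col_shard: (M, K/P) slice of A; B_row_shard: (K/P, N) slice of B."""
    partial_C = A_col_shard @ B_row_shard             # partial (M, N) result on this rank
    dist.all_reduce(partial_C, op=dist.ReduceOp.SUM)  # sum partials across all ranks
    return partial_C                                  # full (M, N) C, replicated everywhere
```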

  2. GEMM + All-scatter: \(B\) is partitioned column-wise, so each rank produces a non-overlapping block of columns of the output \(C\); an all-gather/scatter is then enough to broadcast the final result to every rank. See the figure and the sketch below.

(Figure: GEMM + all-scatter)
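An unfused reference for this variant follows, under the same assumptions as above (torch.distributed already initialized, one rank per GPU; the function name gemm_all_gather is hypothetical): each rank computes its own column block of \(C\), and an all-gather assembles the replicated result.

```python
# Unfused reference for the GEMM + all-scatter/all-gather variant (sketch only).
import torch
import torch.distributed as dist

def gemm_all_gather(A: torch.Tensor, B_col_shard: torch.Tensor) -> torch.Tensor:
    """A: full (M, K) activations; B_col_shard: (K, N/P) column slice of B."""
    C_block = A @ B_col_shard                         # this rank's (M, N/P) column block
    blocks = [torch.empty_like(C_block) for _ in range(dist.get_world_size())]
    dist.all_gather(blocks, C_block)                  # collect every rank's block, in rank order
    return torch.cat(blocks, dim=1)                   # full (M, N) C, replicated everywhere
```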