Skip to content

02 — Cooperative Groups

1. The Problem: CUDA's Synchronization Wall

Traditional CUDA synchronizes only within a thread block:

// Pre-cooperative-groups style: the widest barrier available inside a kernel
// is the block-local __syncthreads(). Expects blockDim.x == 256 (the shared
// tile has exactly 256 slots). `result` is not written in this fragment.
__global__ void naive_reduction(float* data, float* result) {
    __shared__ float tile[256];

    const int lane = threadIdx.x;
    const unsigned base = blockIdx.x * 256u;
    tile[lane] = data[base + lane];

    __syncthreads();   // waits only for the 256 threads of THIS block

    // ... block-level reduction of tile[] would go here ...
    // Combining the per-block results into a GLOBAL sum still needs a second
    // kernel launch, because no barrier here spans multiple blocks.
}

The limitation:

Thread block size limit: 1024 threads
GPU SM count: 108 (A100), 132 (H100 SXM / H200)
Max resident threads: 2048 per SM → 108 × 2048 ≈ 221,000 (A100), 132 × 2048 ≈ 270,000 (H100 SXM)

To sync ALL threads: you MUST end the kernel and launch a second one,
                     which means another host-issued launch + kernel overhead

Cooperative Groups break this wall by providing programmable synchronization at multiple hierarchical levels — within a warp, across a block, and across the entire grid — without leaving the kernel.


2. The Cooperative Groups Hierarchy

GRID (all blocks in a kernel launch)
├── THREAD BLOCK (up to 1024 threads, resident on one SM)
│   │
│   ├── WARP (32 threads; SIMT, but not guaranteed lockstep since Volta)
│   │   │
│   │   └── TILED PARTITION (power-of-two tiles: 1, 2, 4, 8, 16, or 32 threads)
│   │
│   └── COALESCED GROUP (the currently active threads of one warp)
└── MULTI-GRID (spans launches on multiple GPUs; deprecated since CUDA 11.3)

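Each level has its own sync() and metadata (size(), thread_rank()). Warp-level collectives such as shuffle, ballot, and match are available on tiles and coalesced groups, not on whole blocks or grids.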


3. Thread Block Group (Drop-in __syncthreads() replacement)

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void block_sync_example(float* data) {
    __shared__ float sdata[1024];   // sized for the largest legal block

    // Handle to the calling thread block: a first-class object that can be
    // stored, passed to helpers, and queried.
    cg::thread_block cta = cg::this_thread_block();

    // Load one element per thread (no bounds check: data is assumed to hold
    // gridDim.x * blockDim.x floats).
    const unsigned global_id = blockIdx.x * blockDim.x + threadIdx.x;
    sdata[threadIdx.x] = data[global_id];

    // Same barrier as __syncthreads(), expressed on the group object.
    cta.sync();

    // Group metadata (computed for illustration; not used further).
    const unsigned group_threads = cta.size();           // blockDim.x * blockDim.y * blockDim.z
    const unsigned my_rank       = cta.thread_rank();    // 0 .. group_threads-1
    const dim3     local_idx     = cta.thread_index();   // same as threadIdx
    const dim3     block_idx     = cta.group_index();    // same as blockIdx
}

Advantage over raw __syncthreads(): the group is a first-class object you can pass to functions.

// Pass the sync group to helper functions: the signature itself then states
// that the helper synchronizes the whole block. A bare __syncthreads() inside
// a helper gives callers no such hint.
//
// Block-wide tree reduction. On return, sdata[0] holds the sum of
// sdata[0 .. block.size()).
// Preconditions:
//   * called by every thread of the block (block.sync() is a block barrier);
//   * sdata has at least block.size() entries, and the caller has already
//     synced after writing them.
// Unlike a plain "s = size/2; s >>= 1" loop, this is correct for block sizes
// that are not powers of two.
__device__ void reduce_in_block(cg::thread_block& block, float* sdata) {
    const unsigned n    = block.size();
    const unsigned rank = block.thread_rank();

    // First stride = half of the smallest power of two >= n (0 when n == 1,
    // in which case sdata[0] already is the total).
    unsigned s = 1;
    while (s < n) s <<= 1;
    s >>= 1;

    for (; s > 0; s >>= 1) {
        // Fold in a partner only if it exists; for non-power-of-two n the
        // upper half of the first round is partially empty.
        if (rank < s && rank + s < n)
            sdata[rank] += sdata[rank + s];
        block.sync();   // reached by every thread, active or not
    }
}

4. Tiled Partition — Divide a Block into Subgroups

// Sums data[0 .. n) into *result using 32-thread tiles and shuffles.
// Launch: 1-D grid, blockDim.x a multiple of 32 so every tile is full.
// *result must be zero-initialised by the caller. Float atomics make the
// summation order (and so the last bits of the result) nondeterministic.
// (The original fragment used an undeclared `result`; it is now a parameter.)
__global__ void tiled_example(float* data, float* result, int n) {
    cg::thread_block block = cg::this_thread_block();

    // Partition block into groups of 32 (warp-sized)
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);

    // Illustration only: other power-of-two tile sizes (16, 8, 4, 2, 1) work
    // the same way and each tile syncs independently.
    cg::thread_block_tile<16> half_warp = cg::tiled_partition<16>(block);
    cg::thread_block_tile<4>  quad      = cg::tiled_partition<4>(block);
    (void)half_warp;
    (void)quad;

    // Index by global thread id so different blocks read different elements.
    // Out-of-range threads contribute 0 instead of returning, so all 32 lanes
    // stay in the tile and take part in every shuffle below.
    int gid = blockIdx.x * blockDim.x + threadIdx.x;
    float val = (gid < n) ? data[gid] : 0.0f;

    // Warp-level reduction using shuffle within tile (offsets 16, 8, 4, 2, 1)
    for (unsigned offset = warp.size() / 2; offset > 0; offset >>= 1) {
        val += warp.shfl_down(val, offset);
    }

    // Only lane 0 of each tile holds the full sum: one atomic per warp.
    if (warp.thread_rank() == 0)
        atomicAdd(result, val);
}

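The tiled-partition shuffle is usually faster than shared memory for warp-level reductions:
- No __shared__ allocation
- No bank conflicts
- Values move register-to-register, with no store → barrier → load round trip per step (shuffle latency is comparable to a shared-memory access, not "1 cycle")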


5. Coalesced Groups — Handle Warp Divergence

When threads in a warp take different paths (diverge), some threads are inactive. Coalesced groups create a group of only the active threads:

// Sums data[0 .. n) into *result (caller zero-initialises *result). Threads
// past n exit early, so the tail warp has fewer live lanes; the reduction is
// performed over whichever threads are converged at coalesced_threads().
__global__ void divergent_kernel(float* data, float* result, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;  // some threads exit early → divergence

    // Group of the threads that are active together at this point. Each
    // such group reduces its own members and issues its own atomic, so the
    // total is correct whatever the group composition turns out to be.
    cg::coalesced_group active = cg::coalesced_threads();

    float val = data[idx];

    // active.size() is not necessarily a power of two (e.g. 5 live lanes in
    // the tail warp). A loop starting at size()/2 would then drop or
    // double-count lanes. Doubling offsets with a range guard is exact for any
    // size: after the step with offset k, rank r holds the sum of the existing
    // ranks in [r, r + 2k).
    const unsigned count = active.size();
    const unsigned rank  = active.thread_rank();
    for (unsigned offset = 1; offset < count; offset <<= 1) {
        // Every member must execute the shuffle, even if it discards the value.
        float other = active.shfl_down(val, offset);
        if (rank + offset < count)
            val += other;
    }

    if (rank == 0) {
        atomicAdd(result, val);
    }
}

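Coalesced groups do not remove divergence: the diverged paths still execute one after another. What they provide is a correctly masked handle for collective operations (shuffle, reduction, thread_rank()) over exactly the threads that are active at that point. You therefore never shuffle with a lane that has exited. Note that the group size need not be a power of two, so reductions over it must handle arbitrary sizes.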


6. Grid Group — Full Grid Synchronization

The most powerful feature: sync ALL threads across ALL blocks inside one kernel.

// Requires special launch: cudaLaunchCooperativeKernel
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Two-phase kernel. Every thread writes partial[idx]; a grid-wide barrier
// makes all of those writes visible; then global thread 0 sums them serially.
// Must be launched with cudaLaunchCooperativeKernel: grid.sync() is invalid
// under a plain <<<>>> launch. data and partial must each hold
// gridDim.x * blockDim.x floats, since there is no bounds check.
// `compute` is defined elsewhere.
__global__ void grid_sync_kernel(float* data, float* partial, float* result) {
    cg::grid_group grid = cg::this_grid();

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned total_threads = gridDim.x * blockDim.x;

    // Phase 1: each thread produces its own partial value.
    partial[idx] = compute(data[idx]);

    // Device-wide barrier: no thread proceeds until every block arrives here.
    grid.sync();

    // Phase 2: a single thread folds all partials (safe: phase 1 is complete).
    if (idx != 0)
        return;

    float sum = 0.0f;
    for (unsigned i = 0; i < total_threads; ++i)
        sum += partial[i];
    *result = sum;
}

Without grid sync: You'd need:

kernel_phase1<<<grid, block>>>(data, partial);
// No cudaDeviceSynchronize() is needed here: launches on the same stream run
// in order, so phase 2 cannot start before phase 1 finishes. The real cost is
// the second launch plus the kernel boundary, where all of phase 1 must drain
// and any register/shared-memory state is lost.
kernel_phase2<<<grid, block>>>(partial, result);

With grid sync: Single kernel, no CPU involvement, no kernel launch overhead.

Grid Group Launch (C++)

// MUST use cooperative launch API — not standard <<<grid, block>>>
// (grid.sync() is invalid under a plain launch).
// NOTE: assert() is compiled out under NDEBUG; production code should replace
// these asserts with a real error check (e.g. a CUDA_CHECK-style macro).
void* args[] = {&data, &partial, &result};

// Check device supports cooperative launch
int supportsCoopLaunch = 0;
cudaError_t err = cudaDeviceGetAttribute(&supportsCoopLaunch,
    cudaDevAttrCooperativeLaunch, device_id);
assert(err == cudaSuccess && supportsCoopLaunch == 1);

// Check max blocks for cooperative launch: every block must be resident at
// once. The dynamic shared-memory size used here (0) must match the launch.
int numBlocksPerSM = 0;
err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &numBlocksPerSM, grid_sync_kernel, blockSize, 0);
assert(err == cudaSuccess);
int numSMs = 0;
err = cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, device_id);
assert(err == cudaSuccess);
int maxBlocks = numBlocksPerSM * numSMs;
// 0 means not even one block fits per SM at this blockSize (too many
// registers / too much shared memory); a 0-block launch is invalid.
assert(maxBlocks > 0);

// data and partial must hold maxBlocks * blockSize floats: grid_sync_kernel
// indexes them by global thread id without a bounds check.
err = cudaLaunchCooperativeKernel(
    (void*)grid_sync_kernel,
    dim3(maxBlocks),   // must not exceed max blocks!
    dim3(blockSize),
    args,
    0, stream
);
// Returns cudaErrorCooperativeLaunchTooLarge if the grid cannot be co-resident.
assert(err == cudaSuccess);

Critical constraint: grid sync requires that all blocks be resident on the GPU simultaneously. You cannot launch more blocks than the GPU can schedule at once; cudaLaunchCooperativeKernel rejects such a launch with cudaErrorCooperativeLaunchTooLarge.


7. Warp-Level Primitives via Cooperative Groups

Cooperative groups expose warp intrinsics cleanly:

// Tour of the warp-level collectives on a 32-thread tile. The results are
// computed for illustration and not stored.
__global__ void warp_ops_demo(float* data) {
    // 32-thread tile of the current block, i.e. this thread's warp.
    cg::thread_block_tile<32> lanes = cg::tiled_partition<32>(cg::this_thread_block());

    float v       = data[threadIdx.x];
    int   my_lane = lanes.thread_rank();   // 0..31 (not used further here)

    // --- shfl_down: read `v` from lane (my_lane + 16). Lanes 16..31 have no
    //     such source lane and get their own value back.
    v = v + lanes.shfl_down(v, 16);

    // --- shfl_xor (butterfly): the partner lane is my_lane ^ mask. After the
    //     five steps every lane holds the warp-wide total.
    for (unsigned mask = 16; mask != 0; mask /= 2) {
        v += lanes.shfl_xor(v, mask);
    }

    // --- ballot: bit k of the result is set when lane k's predicate is true.
    unsigned positive_lanes = lanes.ballot(v > 0.0f);

    // --- any / all: warp-wide OR / AND of a predicate.
    bool some_positive  = lanes.any(v > 0.0f);
    bool every_positive = lanes.all(v > 0.0f);

    // --- match_any: mask of lanes whose `v` equals this lane's (SM70+).
    unsigned equal_lanes = lanes.match_any(v);
}

8. Practical Example: Parallel Prefix Scan with Grid Groups

Prefix scan (cumulative sum) is a fundamental parallel primitive. With grid sync, it can be done in a single kernel:

// Single-kernel inclusive prefix sum: output[i] = data[0] + ... + data[i].
//
// Launch requirements:
//   * cudaLaunchCooperativeKernel (uses grid.sync()), so gridDim.x must not
//     exceed the number of co-resident blocks;
//   * one element per thread: gridDim.x * blockDim.x >= n;
//   * dynamic shared memory of blockDim.x * sizeof(float) bytes;
//   * block_sums holds at least gridDim.x floats (used as scratch).
// Any blockDim.x works (no power-of-two requirement).
__global__ void prefix_scan(float* data, float* output, float* block_sums, int n) {
    auto grid  = cg::this_grid();
    auto block = cg::this_thread_block();

    extern __shared__ float sdata[];
    int gid = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned tid = threadIdx.x;

    // Phase 1: inclusive scan inside each block (Hillis-Steele). Lanes past n
    // load 0 so they do not disturb the sums.
    sdata[tid] = (gid < n) ? data[gid] : 0.0f;
    block.sync();

    for (unsigned offset = 1; offset < blockDim.x; offset <<= 1) {
        // Everyone reads this round's inputs before anyone overwrites them.
        float addend = (tid >= offset) ? sdata[tid - offset] : 0.0f;
        block.sync();
        sdata[tid] += addend;
        block.sync();
    }

    // The last slot now holds this block's total.
    if (tid == blockDim.x - 1)
        block_sums[blockIdx.x] = sdata[tid];

    // === WAIT FOR ALL BLOCKS TO FINISH PHASE 1 ===
    grid.sync();   // also makes every block_sums write visible grid-wide

    // Phase 2: inclusive scan of the block totals. gridDim.x is limited to
    // co-resident blocks, so one thread walking it serially keeps this simple.
    if (blockIdx.x == 0 && tid == 0) {
        for (unsigned b = 1; b < gridDim.x; ++b)
            block_sums[b] += block_sums[b - 1];
    }

    // === WAIT FOR PHASE 2 ===
    grid.sync();

    // Phase 3: add the total of all preceding blocks to each element.
    float prefix = (blockIdx.x > 0) ? block_sums[blockIdx.x - 1] : 0.0f;
    if (gid < n)
        output[gid] = sdata[tid] + prefix;
}

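Before grid sync, this required three separate kernel launches: a block-local scan, a scan of the block sums, and adding the block offsets. Each phase was separated by a kernel boundary; stream ordering handles the dependency, so no explicit CPU synchronization is needed.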


9. Cooperative Groups in PyTorch Custom CUDA Extensions

// custom_kernel.cu — used via torch.utils.cpp_extension
#include <torch/extension.h>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Adds input[0 .. n) into *output. Each 32-lane tile reduces its values with
// shuffles, then the tile's lane 0 issues a single atomicAdd. *output must
// start at 0; blockDim.x should be a multiple of 32 (the wrapper uses 256).
__global__ void warp_reduce_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cg::this_thread_block());

    // Threads past the end still join the shuffles, contributing 0.
    float acc = 0.0f;
    if (i < n)
        acc = input[i];

    // Offsets 16, 8, 4, 2, 1: afterwards lane 0 holds the tile's sum.
    for (unsigned delta = tile.size() >> 1; delta != 0; delta >>= 1)
        acc += tile.shfl_down(acc, delta);

    if (tile.thread_rank() != 0)
        return;
    atomicAdd(output, acc);
}

// Host wrapper: returns a 1-element float32 tensor (on input's device) holding
// the sum of all elements of `input`.
// Requires: input is a CUDA float32 tensor with fewer than 2^31 elements.
// NOTE(review): launches on the legacy default stream and the current device.
// To honour torch.cuda.stream()/device contexts, use
// at::cuda::getCurrentCUDAStream() and a c10::cuda::CUDAGuard (extra headers).
torch::Tensor warp_reduce(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "warp_reduce: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat,
                "warp_reduce: input must be float32");

    // The kernel walks a flat buffer, so force a dense row-major layout.
    auto flat   = input.contiguous();
    auto output = torch::zeros(1, flat.options());

    const int64_t numel = flat.numel();
    // Empty input: the sum is 0, and launching 0 blocks would be an
    // invalid-configuration error.
    if (numel == 0)
        return output;

    // The kernel takes an int element count.
    const int n = static_cast<int>(numel);
    TORCH_CHECK(static_cast<int64_t>(n) == numel,
                "warp_reduce: input too large for 32-bit indexing");

    const int threads = 256;   // multiple of 32: every tile is a full warp
    const int blocks  = (n + threads - 1) / threads;
    warp_reduce_kernel<<<blocks, threads>>>(
        flat.data_ptr<float>(), output.data_ptr<float>(), n
    );
    // Launches do not return errors directly; surface bad-config errors now.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "warp_reduce_kernel launch failed: ", cudaGetErrorString(err));
    return output;
}

10. When to Use Each Group Type

Situation Use
Drop-in __syncthreads() replacement thread_block.sync()
Warp-level reduction (no shared mem) tiled_partition<32> + shfl_down
Sub-warp parallelism (e.g., tree traversal) tiled_partition<N> (N = 4, 8, 16)
Handle divergent threads cleanly coalesced_threads()
Two-phase algorithm without extra kernel this_grid().sync()
Algorithm with broadcast within small group tiled_partition<N>.shfl()
Count/find active threads in warp coalesced_threads().size()

References