diff --git a/bench/bench_cuda_adam.py b/bench/bench_cuda_adam.py new file mode 100644 index 0000000..5d41a5d --- /dev/null +++ b/bench/bench_cuda_adam.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# bench/bench_cuda_adam.py +# +# Performance benchmark: CUDA Adam kernel vs Triton reference. +# +# Usage: +# python bench/bench_cuda_adam.py # all configs, default warmup/iters +# python bench/bench_cuda_adam.py --iters 200 --warmup 50 +# python bench/bench_cuda_adam.py --csv results.csv + +import argparse +import csv +import sys +import time +from dataclasses import dataclass, field +from typing import Optional + +import torch + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, +} + + +def _make_state(N: int, device: str, quantize: bool, dtype: torch.dtype): + param = torch.randn(N, device=device, dtype=dtype) + grad = torch.randn(N, device=device, dtype=dtype) * 0.01 + if quantize: + G = (N + 31) // 32 + mom = torch.zeros(N, device=device, dtype=torch.int8) + mom_scales = torch.ones(G, device=device, dtype=torch.float16) * 0.01 + var = torch.zeros(N, device=device, dtype=torch.uint8) + var_scales = torch.ones(G, device=device, dtype=torch.float16) * 1e-4 + else: + mom = torch.zeros(N, device=device, dtype=dtype) + mom_scales = torch.empty(0, device=device, dtype=torch.float16) + var = torch.zeros(N, device=device, dtype=dtype) + var_scales = torch.empty(0, device=device, dtype=torch.float16) + return param, grad, mom, mom_scales, var, var_scales + + +def _run_triton(param, grad, mom, mom_scales, var, var_scales, + quantize, decoupled, step): + import flashoptim.optimizers as opt_mod + opt_mod._try_load_cuda_adam_ext() + orig, opt_mod._cuda_adam_ext = opt_mod._cuda_adam_ext, None + 
try: + opt_mod._fused_adam_step( + mom, mom_scales, var, var_scales, param, grad, None, + 1e-3, 0.9, 0.999, 1e-8, 0.01, decoupled, step, + quantize_optim_states=quantize, + ) + finally: + opt_mod._cuda_adam_ext = orig + + +def _run_cuda(param, grad, mom, mom_scales, var, var_scales, + quantize, decoupled, step): + import flashoptim._cuda_adam as ext + ext.adam_step( + mom, mom_scales, var, var_scales, param, grad, None, + 1e-3, 0.9, 0.999, 1e-8, 0.01, step, + quantize, decoupled, 32, + ) + + +def _bench(fn, warmup: int, iters: int) -> float: + """Return median wall-time per call in milliseconds (GPU-synchronised).""" + torch.cuda.synchronize() + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + times = [] + for _ in range(iters): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + torch.cuda.synchronize() + times.append((time.perf_counter() - t0) * 1e3) + + times.sort() + n = len(times) + return times[n // 2] # median + + +# --------------------------------------------------------------------------- +# Benchmark configuration +# --------------------------------------------------------------------------- + +@dataclass +class BenchConfig: + N: int + dtype: str # "bf16" | "fp16" | "fp32" + quantize: bool + decoupled: bool + + +# All combinations to benchmark +CONFIGS = [ + BenchConfig(N=n, dtype=dt, quantize=q, decoupled=d) + for n in [4_096, 65_536, 1_048_576, 16_777_216] # 4K → 16M elements + for dt in ["bf16", "fp16"] + for q in [True, False] + for d in [True, False] +] + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="FlashAdam CUDA vs Triton benchmark") + parser.add_argument("--warmup", type=int, default=30, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=100, help="Timed iterations") + parser.add_argument("--csv", type=str, 
default=None, help="Save results to CSV") + parser.add_argument("--dtype", type=str, default=None, + help="Filter dtype (bf16|fp16|fp32)") + parser.add_argument("--n", type=int, default=None, + help="Filter N (exact match)") + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("ERROR: No CUDA GPU available.", file=sys.stderr) + sys.exit(1) + + # Check extension + try: + import flashoptim._cuda_adam # noqa: F401 + except ImportError: + print("ERROR: flashoptim._cuda_adam not compiled. Run:\n" + " FLASHOPTIM_BUILD_CUDA=1 python setup.py build_ext --inplace", + file=sys.stderr) + sys.exit(1) + + import flashoptim.optimizers as opt_mod + opt_mod._try_load_cuda_adam_ext() + + # Pre-warm Triton JIT for all (dtype, quantize, decoupled) combos + # Use N=65536 to ensure the exact same Triton kernel (tiled for that size) is compiled. + device = "cuda" + print("Pre-warming Triton JIT...", end=" ", flush=True) + seen = set() + for cfg in CONFIGS: + key = (cfg.dtype, cfg.quantize, cfg.decoupled) + if key in seen: + continue + seen.add(key) + dtype = _DTYPE_MAP[cfg.dtype] + p, g, m, ms, v, vs = _make_state(65536, device, cfg.quantize, dtype) + for _ in range(3): + _run_triton(p.clone(), g, m.clone(), ms.clone(), v.clone(), vs.clone(), + cfg.quantize, cfg.decoupled, 1) + torch.cuda.synchronize() + print("done") + gpu_name = torch.cuda.get_device_name(0) + p = torch.cuda.get_device_properties(0) + print(f"\n{'='*72}") + print(f" GPU : {gpu_name} (SM {p.major}.{p.minor})") + print(f" Warmup: {args.warmup} Timed: {args.iters}") + print(f"{'='*72}") + print(f"{'N':>12} {'dtype':>5} {'quant':>5} {'decoup':>6} " + f"{'Triton(ms)':>10} {'CUDA(ms)':>10} {'Speedup':>7} " + f"{'BW Triton':>10} {'BW CUDA':>10}") + print(f"{'-'*92}") + + configs = CONFIGS + if args.dtype: + configs = [c for c in configs if c.dtype == args.dtype] + if args.n: + configs = [c for c in configs if c.N == args.n] + + rows = [] + for cfg in configs: + dtype = _DTYPE_MAP[cfg.dtype] + param, 
grad, mom, ms, var, vs = _make_state(cfg.N, device, cfg.quantize, dtype) + + # Clones for each backend so state doesn't accumulate differences + def triton_step(step=[1]): + _run_triton(param.clone(), grad, mom.clone(), ms.clone(), + var.clone(), vs.clone(), cfg.quantize, cfg.decoupled, step[0]) + step[0] += 1 + + def cuda_step(step=[1]): + _run_cuda(param.clone(), grad, mom.clone(), ms.clone(), + var.clone(), vs.clone(), cfg.quantize, cfg.decoupled, step[0]) + step[0] += 1 + + t_triton = _bench(triton_step, args.warmup, args.iters) + t_cuda = _bench(cuda_step, args.warmup, args.iters) + speedup = t_triton / t_cuda if t_cuda > 0 else float("inf") + + # Approximate memory bandwidth (bytes read + written per step) + elem_bytes = 2 if cfg.dtype in ("bf16", "fp16") else 4 + if cfg.quantize: + # mom(i8) + var(u8) + mom_scales(f16) + var_scales(f16) + param + grad + G = (cfg.N + 31) // 32 + bw_elems = 2 * cfg.N + 2 * G * 2 + 2 * cfg.N * elem_bytes + else: + bw_elems = 4 * cfg.N * elem_bytes # mom + var + param + grad + bw_gb_triton = bw_elems / (t_triton * 1e-3) / 1e9 + bw_gb_cuda = bw_elems / (t_cuda * 1e-3) / 1e9 + + print(f"{cfg.N:>12,} {cfg.dtype:>5} {str(cfg.quantize):>5} " + f"{str(cfg.decoupled):>6} " + f"{t_triton:>10.3f} {t_cuda:>10.3f} {speedup:>7.2f}x " + f"{bw_gb_triton:>9.1f}G {bw_gb_cuda:>9.1f}G") + + rows.append({ + "gpu": gpu_name, + "N": cfg.N, + "dtype": cfg.dtype, + "quantize": cfg.quantize, + "decoupled": cfg.decoupled, + "triton_ms": round(t_triton, 4), + "cuda_ms": round(t_cuda, 4), + "speedup": round(speedup, 4), + "bw_triton_GBs": round(bw_gb_triton, 2), + "bw_cuda_GBs": round(bw_gb_cuda, 2), + }) + + print(f"{'='*92}\n") + + if args.csv: + with open(args.csv, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + print(f"Results saved to {args.csv}") + + +if __name__ == "__main__": + main() diff --git a/flashoptim/csrc/flash_adam_cuda.cu 
// SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc.
// SPDX-License-Identifier: Apache-2.0
//
// flash_adam_cuda.cu
// Fused Adam / AdamW CUDA kernel replacing the Triton implementation.
//

#include "flash_adam_cuda.cuh"

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>
#include <type_traits>
#include <algorithm>

// ---------------------------------------------------------------------------
// Vectorised load/store helpers (VEC = 4 elements per 128-/64-/32-bit access).
// Contract: when `valid` is true, ALL VEC elements starting at `base` must be
// in bounds — a partially valid tail must use the scalar fallback paths in the
// kernel instead.
// ---------------------------------------------------------------------------

/// Load VEC=4 elements of T and widen to float. `valid=false` yields zeros.
template <typename T>
__device__ __forceinline__ void vec_load(const T* ptr, int base, bool valid, float out[VEC]) {
    if (valid) {
        if constexpr (sizeof(T) == 4) {
            float4 v = *reinterpret_cast<const float4*>(ptr + base);
            out[0] = v.x; out[1] = v.y; out[2] = v.z; out[3] = v.w;
        } else if constexpr (sizeof(T) == 2) {
            uint64_t raw = *reinterpret_cast<const uint64_t*>(ptr + base);
            const T* vals = reinterpret_cast<const T*>(&raw);
            out[0] = param_to_float(vals[0]);
            out[1] = param_to_float(vals[1]);
            out[2] = param_to_float(vals[2]);
            out[3] = param_to_float(vals[3]);
        }
    } else {
        out[0] = out[1] = out[2] = out[3] = 0.f;
    }
}

/// Load VEC=4 int8 values (one 32-bit transaction) and convert to float.
__device__ __forceinline__ void vec_load_i8(const int8_t* ptr, int base, bool valid, float out[VEC]) {
    if (valid) {
        int32_t raw = *reinterpret_cast<const int32_t*>(ptr + base);  // 32-bit = 4×int8
        out[0] = (float)((int8_t)(raw & 0xFF));
        out[1] = (float)((int8_t)((raw >> 8) & 0xFF));
        out[2] = (float)((int8_t)((raw >> 16) & 0xFF));
        out[3] = (float)((int8_t)((raw >> 24) & 0xFF));
    } else {
        out[0] = out[1] = out[2] = out[3] = 0.f;
    }
}

/// Load VEC=4 uint8 values and convert to float.
__device__ __forceinline__ void vec_load_u8(const uint8_t* ptr, int base, bool valid, float out[VEC]) {
    if (valid) {
        uint32_t raw = *reinterpret_cast<const uint32_t*>(ptr + base);
        out[0] = (float)(raw & 0xFF);
        out[1] = (float)((raw >> 8) & 0xFF);
        out[2] = (float)((raw >> 16) & 0xFF);
        out[3] = (float)((raw >> 24) & 0xFF);
    } else {
        out[0] = out[1] = out[2] = out[3] = 0.f;
    }
}

/// Store VEC=4 float values back to a typed array.
template <typename T>
__device__ __forceinline__ void vec_store(T* ptr, int base, bool valid, const float src[VEC]) {
    if (valid) {
        if constexpr (sizeof(T) == 4) {
            float4 v = {src[0], src[1], src[2], src[3]};
            *reinterpret_cast<float4*>(ptr + base) = v;
        } else if constexpr (sizeof(T) == 2) {
            T vals[VEC] = {float_to_param<T>(src[0]), float_to_param<T>(src[1]),
                           float_to_param<T>(src[2]), float_to_param<T>(src[3])};
            *reinterpret_cast<uint64_t*>(ptr + base) = *reinterpret_cast<const uint64_t*>(vals);
        }
    }
}

/// Store VEC=4 float values as int8 (round-to-nearest, clamp to [-127,127]).
__device__ __forceinline__ void vec_store_i8(int8_t* ptr, int base, bool valid, const float src[VEC]) {
    if (valid) {
        int32_t out = 0;
        for (int i = 0; i < VEC; i++) {
            int v = __float2int_rn(src[i]);
            v = max(-127, min(127, v));
            out |= ((int32_t)(uint8_t)(int8_t)v) << (i * 8);
        }
        *reinterpret_cast<int32_t*>(ptr + base) = out;
    }
}

/// Store VEC=4 float values as uint8 (round-to-nearest, clamp to [0,255]).
__device__ __forceinline__ void vec_store_u8(uint8_t* ptr, int base, bool valid, const float src[VEC]) {
    if (valid) {
        uint32_t out = 0;
        for (int i = 0; i < VEC; i++) {
            unsigned int v = __float2uint_rn(src[i]);
            v = min(v, 255u);
            out |= v << (i * 8);
        }
        *reinterpret_cast<uint32_t*>(ptr + base) = out;
    }
}


template <
    typename ParamT,   // __nv_bfloat16 | __half | float
    int GROUP_SIZE,    // 32 (must equal warp size for warp-reduce trick)
    int BLOCK_SIZE,    // threads per block (must be multiple of GROUP_SIZE)
    bool kQuantize,    // INT8 optimizer states
    bool kDecoupled,   // AdamW decoupled weight decay
    bool kUseECC,      // ECC error correction bits
    int kECCBits       // 8 or 16
>
__global__ void flash_adam_kernel(
    int8_t* __restrict__ mom_ptr,
    uint8_t* __restrict__ var_ptr,
    __half* __restrict__ mom_scales_ptr,
    __half* __restrict__ var_scales_ptr,
    ParamT* __restrict__ param_ptr,
    const ParamT* __restrict__ grad_ptr,
    void* __restrict__ ecc_ptr,
    int N,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float weight_decay,
    int step
) {
    static_assert(GROUP_SIZE == 32, "GROUP_SIZE must equal warp size (32)");
    static_assert(BLOCK_SIZE % GROUP_SIZE == 0, "BLOCK_SIZE must be multiple of GROUP_SIZE");

    constexpr int ELEMS_PER_BLOCK = BLOCK_SIZE * VEC;

    const int tid = threadIdx.x;
    const int lane = tid % 32;

    // Bias-correction denominators (1 - beta^t).
    const float bc1 = 1.f - __powf(beta1, (float)step);
    const float bc2 = 1.f - __powf(beta2, (float)step);

    for (int block_start = (int)blockIdx.x * ELEMS_PER_BLOCK;
         block_start < N;
         block_start += (int)gridDim.x * ELEMS_PER_BLOCK)
    {
        const int elem_base = block_start + tid * VEC;
        const bool in_bounds = (elem_base + VEC - 1) < N;  // full vector in range
        const bool any_valid = (elem_base < N);            // at least one element in range

        if (!any_valid) continue;  // entire thread is out of bounds

        // BUGFIX: the vectorised loads were gated on `in_bounds || any_valid`,
        // which equals `any_valid` — a tail vector straddling N issued a full
        // 128-bit load past the end of the buffer (out-of-bounds read).
        // Partial tails now use scalar loads, mirroring the scalar tail stores.
        float grad_f[VEC], param_f[VEC];
        if (in_bounds) {
            vec_load<ParamT>(grad_ptr, elem_base, true, grad_f);
            vec_load<ParamT>(param_ptr, elem_base, true, param_f);
        } else {
            for (int i = 0; i < VEC; i++) {
                const bool ok = (elem_base + i) < N;
                grad_f[i]  = ok ? param_to_float(grad_ptr[elem_base + i])  : 0.f;
                param_f[i] = ok ? param_to_float(param_ptr[elem_base + i]) : 0.f;
            }
        }

        if constexpr (kUseECC) {
            using EccT = typename std::conditional<kECCBits == 8, int8_t, int16_t>::type;
            const EccT* ecc_typed = reinterpret_cast<const EccT*>(ecc_ptr);
            for (int i = 0; i < VEC; i++) {
                if (elem_base + i < N) {
                    EccT ecc_val = ecc_typed[elem_base + i];
                    ParamT x_narrow = param_ptr[elem_base + i];
                    param_f[i] = decode_ecc<ParamT, kECCBits>(x_narrow, (int)ecc_val);
                }
            }
        }

        // Classic (coupled) L2 weight decay folds into the gradient.
        if constexpr (!kDecoupled) {
            for (int i = 0; i < VEC; i++)
                grad_f[i] += param_f[i] * weight_decay;
        }

        float mom_f[VEC], var_f[VEC];

        if constexpr (kQuantize) {
            // Number of threads that cover one GROUP_SIZE-element group:
            constexpr int THREADS_PER_GROUP = GROUP_SIZE / VEC;  // 32/4 = 8

            const int group_idx_in_block = tid / THREADS_PER_GROUP;
            const int group_idx_global = block_start / GROUP_SIZE + group_idx_in_block;
            const int num_total_groups = (N + GROUP_SIZE - 1) / GROUP_SIZE;
            const bool group_valid = (group_idx_global < num_total_groups);

            constexpr unsigned int reduce_mask = 0xffffffffu;
            const int group_lane = lane % THREADS_PER_GROUP;
            const int group_start_lane = (lane / THREADS_PER_GROUP) * THREADS_PER_GROUP;

            // Load raw quantised values directly into mom_f/var_f to save
            // registers; scalar fallback for the partial tail (see BUGFIX above).
            if (in_bounds) {
                vec_load_i8(mom_ptr, elem_base, true, mom_f);
                vec_load_u8(var_ptr, elem_base, true, var_f);
            } else {
                for (int i = 0; i < VEC; i++) {
                    const bool ok = (elem_base + i) < N;
                    mom_f[i] = ok ? (float)mom_ptr[elem_base + i] : 0.f;
                    var_f[i] = ok ? (float)var_ptr[elem_base + i] : 0.f;
                }
            }

            // Load scales (only lane 0 of each group reads; broadcast via shuffle)
            float mom_scale = 0.f, var_scale = 0.f;
            if (group_lane == 0 && group_valid) {
                mom_scale = __half2float(mom_scales_ptr[group_idx_global]);
                var_scale = __half2float(var_scales_ptr[group_idx_global]);
            }
            mom_scale = __shfl_sync(reduce_mask, mom_scale, group_start_lane);
            var_scale = __shfl_sync(reduce_mask, var_scale, group_start_lane);

            // Dequantise in-place
            for (int i = 0; i < VEC; i++) {
                float m_t = mom_f[i] / 127.f;
                mom_f[i] = inv_softsign(m_t) * mom_scale;

                float v_t = var_f[i] / 255.f;
                float vs = v_t * var_scale;
                var_f[i] = vs * vs;  // undo sqrt
            }
        } else {
            // Full-precision states: load directly (scalar fallback for tails).
            if (in_bounds) {
                if constexpr (sizeof(ParamT) == 4) {
                    vec_load<float>(reinterpret_cast<const float*>(mom_ptr), elem_base, true, mom_f);
                    vec_load<float>(reinterpret_cast<const float*>(var_ptr), elem_base, true, var_f);
                } else {
                    vec_load<ParamT>(reinterpret_cast<const ParamT*>(mom_ptr), elem_base, true, mom_f);
                    vec_load<ParamT>(reinterpret_cast<const ParamT*>(var_ptr), elem_base, true, var_f);
                }
            } else {
                for (int i = 0; i < VEC; i++) {
                    const bool ok = (elem_base + i) < N;
                    if constexpr (sizeof(ParamT) == 4) {
                        mom_f[i] = ok ? reinterpret_cast<const float*>(mom_ptr)[elem_base + i] : 0.f;
                        var_f[i] = ok ? reinterpret_cast<const float*>(var_ptr)[elem_base + i] : 0.f;
                    } else {
                        mom_f[i] = ok ? param_to_float(reinterpret_cast<const ParamT*>(mom_ptr)[elem_base + i]) : 0.f;
                        var_f[i] = ok ? param_to_float(reinterpret_cast<const ParamT*>(var_ptr)[elem_base + i]) : 0.f;
                    }
                }
            }
        }

        for (int i = 0; i < VEC; i++) {
            // Update first moment
            mom_f[i] = beta1 * mom_f[i] + (1.f - beta1) * grad_f[i];
            // Update second moment
            var_f[i] = beta2 * var_f[i] + (1.f - beta2) * grad_f[i] * grad_f[i];
        }

        // Decoupled weight decay: param *= (1 - wd)
        if constexpr (kDecoupled) {
            const float wd_scale = 1.f - weight_decay;
            for (int i = 0; i < VEC; i++)
                param_f[i] *= wd_scale;
        }

        // Bias-corrected param update
        for (int i = 0; i < VEC; i++) {
            float m_hat = mom_f[i] / bc1;
            float v_hat = var_f[i] / bc2;
            param_f[i] -= lr * m_hat / (sqrtf(v_hat) + eps);
        }

        if (in_bounds) {
            vec_store<ParamT>(param_ptr, elem_base, true, param_f);
        } else {
            for (int i = 0; i < VEC; i++)
                if (elem_base + i < N)
                    param_ptr[elem_base + i] = float_to_param<ParamT>(param_f[i]);
        }

        if constexpr (kUseECC) {
            using EccT = typename std::conditional<kECCBits == 8, int8_t, int16_t>::type;
            EccT* ecc_typed = reinterpret_cast<EccT*>(ecc_ptr);
            for (int i = 0; i < VEC; i++) {
                if (elem_base + i < N) {
                    ParamT x_narrow = param_ptr[elem_base + i];
                    ecc_typed[elem_base + i] =
                        (EccT)encode_ecc<ParamT, kECCBits>(param_f[i], x_narrow);
                }
            }
        }

        if constexpr (kQuantize) {
            constexpr int THREADS_PER_GROUP = GROUP_SIZE / VEC;  // 8

            const int group_lane = lane % THREADS_PER_GROUP;
            const int group_start_lane = (lane / THREADS_PER_GROUP) * THREADS_PER_GROUP;
            const int group_idx_in_block = tid / THREADS_PER_GROUP;
            const int group_idx_global = block_start / GROUP_SIZE + group_idx_in_block;
            const int num_total_groups = (N + GROUP_SIZE - 1) / GROUP_SIZE;
            const bool group_valid = (group_idx_global < num_total_groups);

            constexpr unsigned int reduce_mask = 0xffffffffu;

            // Compute sqrt(var) in-place, accumulate per-thread absmax
            float mom_abs = 0.f, var_sqrt_abs = 0.f;
            for (int i = 0; i < VEC; i++) {
                mom_abs = fmaxf(mom_abs, fabsf(mom_f[i]));
                var_f[i] = sqrtf(fmaxf(var_f[i], 0.f));
                var_sqrt_abs = fmaxf(var_sqrt_abs, var_f[i]);
            }

            float group_mom = mom_abs, group_var = var_sqrt_abs;
#pragma unroll
            for (int off = THREADS_PER_GROUP >> 1; off > 0; off >>= 1) {
                group_mom = fmaxf(group_mom, __shfl_xor_sync(reduce_mask, group_mom, off));
                group_var = fmaxf(group_var, __shfl_xor_sync(reduce_mask, group_var, off));
            }
            // Now group_start_lane holds the correct group max; broadcast to all in group.
            mom_abs = fmaxf(__shfl_sync(reduce_mask, group_mom, group_start_lane), 1e-12f);
            var_sqrt_abs = fmaxf(__shfl_sync(reduce_mask, group_var, group_start_lane), 1e-12f);

            // Quantise in-place: use rcp multiply instead of divide for speed
            const float inv_mom_abs = 1.f / mom_abs;
            const float inv_var_sqrt_abs = 1.f / var_sqrt_abs;
            for (int i = 0; i < VEC; i++) {
                mom_f[i] = softsign(mom_f[i] * inv_mom_abs) * 127.f;
                var_f[i] = (var_f[i] * inv_var_sqrt_abs) * 255.f;
            }

            if (in_bounds) {
                vec_store_i8(mom_ptr, elem_base, true, mom_f);
                vec_store_u8(var_ptr, elem_base, true, var_f);
            } else {
                for (int i = 0; i < VEC; i++) {
                    if (elem_base + i < N) {
                        int mv = __float2int_rn(mom_f[i]);
                        mom_ptr[elem_base + i] = (int8_t)max(-127, min(127, mv));
                        unsigned int vv = __float2uint_rn(var_f[i]);
                        var_ptr[elem_base + i] = (uint8_t)min(vv, 255u);
                    }
                }
            }

            // Store scales (only the first thread of each group)
            if (group_lane == 0 && group_valid) {
                mom_scales_ptr[group_idx_global] = __float2half(mom_abs);
                var_scales_ptr[group_idx_global] = __float2half(var_sqrt_abs);
            }
        } else {
            // Store states at param precision
            if (in_bounds) {
                if constexpr (sizeof(ParamT) == 4) {
                    vec_store<float>(reinterpret_cast<float*>(mom_ptr), elem_base, true, mom_f);
                    vec_store<float>(reinterpret_cast<float*>(var_ptr), elem_base, true, var_f);
                } else {
                    vec_store<ParamT>(reinterpret_cast<ParamT*>(mom_ptr), elem_base, true, mom_f);
                    vec_store<ParamT>(reinterpret_cast<ParamT*>(var_ptr), elem_base, true, var_f);
                }
            } else {
                for (int i = 0; i < VEC; i++) {
                    if (elem_base + i < N) {
                        if constexpr (sizeof(ParamT) == 4) {
                            reinterpret_cast<float*>(mom_ptr)[elem_base + i] = mom_f[i];
                            reinterpret_cast<float*>(var_ptr)[elem_base + i] = var_f[i];
                        } else {
                            reinterpret_cast<ParamT*>(mom_ptr)[elem_base + i] = float_to_param<ParamT>(mom_f[i]);
                            reinterpret_cast<ParamT*>(var_ptr)[elem_base + i] = float_to_param<ParamT>(var_f[i]);
                        }
                    }
                }
            }
        }
    }  // end grid-stride loop
}


// Macro to launch a particular instantiation
#define LAUNCH_KERNEL(ParamT, kQ, kD, kE, kB)                              \
    flash_adam_kernel<ParamT, 32, BLOCK_SIZE, kQ, kD, kE, kB>              \
        <<<grid, BLOCK_SIZE, 0, stream>>>(                                 \
            mom_ptr, var_ptr, mom_scales_ptr, var_scales_ptr,              \
            reinterpret_cast<ParamT*>(param_ptr),                          \
            reinterpret_cast<const ParamT*>(grad_ptr),                     \
            ecc_ptr, N, lr, beta1, beta2, eps, weight_decay, step)

// Helper: dispatch over (kDecoupled, kUseECC, kECCBits) for a fixed ParamT and kQuantize
#define DISPATCH_FLAGS(ParamT, kQ)                                         \
    do {                                                                   \
        if (use_ecc && ecc_bits == 8) {                                    \
            if (decoupled) { LAUNCH_KERNEL(ParamT, kQ, true, true, 8); }   \
            else { LAUNCH_KERNEL(ParamT, kQ, false, true, 8); }            \
        } else if (use_ecc && ecc_bits == 16) {                            \
            if (decoupled) { LAUNCH_KERNEL(ParamT, kQ, true, true, 16); }  \
            else { LAUNCH_KERNEL(ParamT, kQ, false, true, 16); }           \
        } else {                                                           \
            if (decoupled) { LAUNCH_KERNEL(ParamT, kQ, true, false, 8); }  \
            else { LAUNCH_KERNEL(ParamT, kQ, false, false, 8); }           \
        }                                                                  \
    } while (0)

void launch_flash_adam(
    int8_t* mom_ptr,
    uint8_t* var_ptr,
    __half* mom_scales_ptr,
    __half* var_scales_ptr,
    void* param_ptr,
    const void* grad_ptr,
    void* ecc_ptr,
    int param_dtype,
    int N,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float weight_decay,
    int step,
    bool quantize,
    bool decoupled,
    bool use_ecc,
    int ecc_bits,
    int group_size,
    cudaStream_t stream
) {
    if (N == 0) return;
    (void)group_size;  // validated to be 32 by the caller; kernel is compiled for 32

    // Cache SM count per device to avoid repeated cudaGetDeviceProperties calls.
    constexpr int kMaxCachedDevices = 64;
    static int cached_sm_count[kMaxCachedDevices] = {};  // index by device id
    int device = 0;
    cudaGetDevice(&device);
    int sm_count;
    if (device >= 0 && device < kMaxCachedDevices) {
        if (cached_sm_count[device] == 0) {
            cudaDeviceProp prop;
            cudaGetDeviceProperties(&prop, device);
            cached_sm_count[device] = prop.multiProcessorCount;
        }
        sm_count = cached_sm_count[device];
    } else {
        // BUGFIX: previously indexed the cache without a bounds check; fall
        // back to an uncached query for device ids beyond the cache size.
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, device);
        sm_count = prop.multiProcessorCount;
    }

    constexpr int BLOCK_SIZE = 256;
    constexpr int ELEMS_PER_BLOCK = BLOCK_SIZE * VEC;  // 256 * 4 = 1024
    const int total_blocks = (N + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK;
    const int grid = std::min(2 * sm_count, total_blocks);

    // Dispatch on param_dtype × quantize
    if (param_dtype == 0) {  // bf16
        if (quantize) { DISPATCH_FLAGS(__nv_bfloat16, true); }
        else { DISPATCH_FLAGS(__nv_bfloat16, false); }
    } else if (param_dtype == 1) {  // fp16
        if (quantize) { DISPATCH_FLAGS(__half, true); }
        else { DISPATCH_FLAGS(__half, false); }
    } else {  // fp32
        if (quantize) { DISPATCH_FLAGS(float, true); }
        else { DISPATCH_FLAGS(float, false); }
    }
}

// ===========================================================================
// flashoptim/csrc/flash_adam_cuda.cuh
// ===========================================================================
// SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc.
// SPDX-License-Identifier: Apache-2.0
//
// Shared device helpers and kernel declarations for the fused Adam CUDA kernel.

#pragma once

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <type_traits>
#include <algorithm>

// Number of elements loaded per thread per iteration (128-bit vector load).
// 4 × float32 = 16 bytes = one 128-bit transaction.
static constexpr int VEC = 4;

#ifdef __CUDACC__

// ---------------------------------------------------------------------------
// Bit-level traits of the supported parameter types (for ECC exponent math).
// ---------------------------------------------------------------------------

template <typename T> struct ParamTypeTraits {};

template <> struct ParamTypeTraits<__nv_bfloat16> {
    static constexpr int mantissa_bits = 7;
    static constexpr int exponent_bits = 8;
    static constexpr int exponent_bias = 127;
    using uint_type = uint16_t;
    __device__ static uint16_t bitcast_to_uint(__nv_bfloat16 x) {
        // BUGFIX: use memcpy instead of a type-punning reinterpret_cast
        // (strict-aliasing UB); matches the float specialisation below.
        uint16_t u;
        __builtin_memcpy(&u, &x, sizeof(u));
        return u;
    }
};

template <> struct ParamTypeTraits<__half> {
    static constexpr int mantissa_bits = 10;
    static constexpr int exponent_bits = 5;
    static constexpr int exponent_bias = 15;
    using uint_type = uint16_t;
    __device__ static uint16_t bitcast_to_uint(__half x) {
        uint16_t u;
        __builtin_memcpy(&u, &x, sizeof(u));  // see note above
        return u;
    }
};

template <> struct ParamTypeTraits<float> {
    static constexpr int mantissa_bits = 23;
    static constexpr int exponent_bits = 8;
    static constexpr int exponent_bias = 127;
    using uint_type = uint32_t;
    __device__ static uint32_t bitcast_to_uint(float x) {
        uint32_t u;
        __builtin_memcpy(&u, &x, sizeof(u));  // avoids __float_as_uint which requires CUDA intrinsics
        return u;
    }
};

// ---------------------------------------------------------------------------
// Type-safe float ↔ ParamT conversion helpers
// ---------------------------------------------------------------------------

/// Convert float → ParamT without relying on implicit __half constructor.
template <typename ParamT>
__device__ __forceinline__ ParamT float_to_param(float x);

template <>
__device__ __forceinline__ __nv_bfloat16 float_to_param<__nv_bfloat16>(float x) {
    return __float2bfloat16(x);
}

template <>
__device__ __forceinline__ __half float_to_param<__half>(float x) {
    return __float2half(x);
}

template <>
__device__ __forceinline__ float float_to_param<float>(float x) {
    return x;
}

/// Convert ParamT → float without relying on implicit __half cast.
template <typename ParamT>
__device__ __forceinline__ float param_to_float(ParamT x);

template <>
__device__ __forceinline__ float param_to_float<__nv_bfloat16>(__nv_bfloat16 x) {
    return __bfloat162float(x);
}

template <>
__device__ __forceinline__ float param_to_float<__half>(__half x) {
    return __half2float(x);
}

template <>
__device__ __forceinline__ float param_to_float<float>(float x) {
    return x;
}

/// Warp-wide absolute maximum reduction.
__device__ __forceinline__ float warp_absmax(float val) {
    val = fabsf(val);
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, offset));
    return val;
}

/// Warp-wide maximum (for values that are always >= 0).
__device__ __forceinline__ float warp_max(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, offset));
    return val;
}

__device__ __forceinline__ float softsign(float x) {
    return 2.f * x / (1.f + fabsf(x));
}

__device__ __forceinline__ float inv_softsign(float y) {
    return y / (2.f - fabsf(y));
}

/// Unbiased binary exponent of |x| (subnormals map to the minimum exponent).
template <typename ParamT>
__device__ __forceinline__ int get_unbiased_exponent(ParamT x) {
    using Traits = ParamTypeTraits<ParamT>;
    // abs via negation to avoid __CUDA_NO_HALF_OPERATORS__ issues
    ParamT ax = (param_to_float(x) < 0.f) ? float_to_param<ParamT>(-param_to_float(x)) : x;
    auto bits = Traits::bitcast_to_uint(ax);
    int exp_bits = (int)(bits >> Traits::mantissa_bits);
    return (exp_bits == 0) ? (1 - Traits::exponent_bias)
                           : (exp_bits - Traits::exponent_bias);
}

/// log2 of one ULP of x in ParamT precision.
template <typename ParamT>
__device__ __forceinline__ int log_ulp(ParamT x) {
    return get_unbiased_exponent(x) - ParamTypeTraits<ParamT>::mantissa_bits;
}

/// Encode the rounding error (x_f32 - x_narrow) into a signed kECCBits value,
/// scaled relative to the ULP of x_narrow. The exp2f splitting via `h`
/// avoids overflow of a single exp2f argument for extreme exponents.
template <typename ParamT, int kECCBits>
__device__ __forceinline__ int encode_ecc(float x_f32, ParamT x_narrow) {
    constexpr int signed_max = (kECCBits == 8) ? 127 : 32767;
    float x_recon = param_to_float(x_narrow);
    float e = x_f32 - x_recon;

    int ls = log_ulp(x_narrow) - 1;
    float neg_ls = (float)(-ls);
    float h = floorf(neg_ls * 0.5f);
    float temp = e * exp2f(h);
    float e_norm = temp * exp2f(neg_ls - h);
    float e_clamped = fmaxf(-1.f, fminf(1.f, e_norm));
    float scaled = e_clamped * (float)signed_max;

    float sign = (scaled >= 0.f) ? 1.f : -1.f;
    int rounded = (int)(fabsf(scaled) + 0.5f);
    rounded = min(rounded, signed_max);
    return (int)(sign * (float)rounded);
}

/// Inverse of encode_ecc: reconstruct an fp32 approximation from the narrow
/// value plus its stored error-correction code.
template <typename ParamT, int kECCBits>
__device__ __forceinline__ float decode_ecc(ParamT x_narrow, int ecc_val) {
    constexpr int signed_max = (kECCBits == 8) ? 127 : 32767;
    float x_recon = param_to_float(x_narrow);
    int ls = log_ulp(x_narrow) - 1;
    float log_scale_f = (float)ls;
    float h = floorf(log_scale_f * 0.5f);
    float correction = ((float)ecc_val / (float)signed_max) * exp2f(h) * exp2f(log_scale_f - h);
    return x_recon + correction;
}

// Kernel declaration lives inside the __CUDACC__ guard: host-only translation
// units (the pybind binding) only need launch_flash_adam below, and __global__
// is not defined for a plain host compiler.
template <
    typename ParamT,
    int GROUP_SIZE,
    int BLOCK_SIZE,
    bool kQuantize,
    bool kDecoupled,
    bool kUseECC,
    int kECCBits
>
__global__ void flash_adam_kernel(
    int8_t* __restrict__ mom_ptr,
    uint8_t* __restrict__ var_ptr,
    __half* __restrict__ mom_scales_ptr,
    __half* __restrict__ var_scales_ptr,
    ParamT* __restrict__ param_ptr,
    const ParamT* __restrict__ grad_ptr,
    void* __restrict__ ecc_ptr,
    int N,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float weight_decay,
    int step
);

#endif  // __CUDACC__

/// Host-callable launcher: selects the kernel instantiation from runtime
/// flags and enqueues it on `stream`.
void launch_flash_adam(
    int8_t* mom_ptr,
    uint8_t* var_ptr,
    __half* mom_scales_ptr,
    __half* var_scales_ptr,
    void* param_ptr,
    const void* grad_ptr,
    void* ecc_ptr,
    int param_dtype,
    int N,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float weight_decay,
    int step,
    bool quantize,
    bool decoupled,
    bool use_ecc,
    int ecc_bits,
    int group_size,
    cudaStream_t stream
);
b/flashoptim/csrc/flash_adam_cuda_ext.cpp new file mode 100644 index 0000000..2e648ef --- /dev/null +++ b/flashoptim/csrc/flash_adam_cuda_ext.cpp @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// flash_adam_cuda_ext.cpp +// pybind11 / PyTorch C++ extension binding for the fused Adam CUDA kernel. +// + + +#include +#include // at::cuda::getCurrentCUDAStream +#include +#include "flash_adam_cuda.cuh" + +// Map a torch dtype to our internal integer code: +// 0 = bfloat16, 1 = float16, 2 = float32 +static int dtype_code(const torch::Tensor& t) { + if (t.dtype() == torch::kBFloat16) return 0; + if (t.dtype() == torch::kFloat16) return 1; + if (t.dtype() == torch::kFloat32) return 2; + TORCH_CHECK(false, "flash_adam: unsupported param dtype ", t.dtype()); +} + +/// Main entry point called from Python. +/// +/// Arguments match `_fused_adam_step` in optimizers.py exactly so that the +/// Python dispatcher can call this with zero overhead. 
+/// +/// mom – int8 tensor (quantized) or ParamT tensor (full-precision) +/// mom_scales – float16 tensor (only used when quantize=True) +/// var – uint8 tensor (quantized) or ParamT tensor (full-precision) +/// var_scales – float16 tensor (only used when quantize=True) +/// param – bf16/fp16/fp32 parameter tensor (in-place updated) +/// grad – gradient tensor (same dtype as param) +/// ecc – optional int8/int16 tensor, or empty tensor (USE_ECC=false) +void adam_step( + torch::Tensor mom, + torch::Tensor mom_scales, + torch::Tensor var, + torch::Tensor var_scales, + torch::Tensor param, + torch::Tensor grad, + torch::optional ecc, + double lr, + double beta1, + double beta2, + double eps, + double weight_decay, + int64_t step, + bool quantize, + bool decoupled, + int64_t group_size +) { + TORCH_CHECK(param.is_cuda(), "flash_adam: param must be a CUDA tensor"); + TORCH_CHECK(param.is_contiguous(), "flash_adam: param must be contiguous"); + TORCH_CHECK(grad.is_contiguous(), "flash_adam: grad must be contiguous"); + TORCH_CHECK(mom.is_contiguous(), "flash_adam: mom must be contiguous"); + TORCH_CHECK(var.is_contiguous(), "flash_adam: var must be contiguous"); + TORCH_CHECK(group_size == 32, "flash_adam: only group_size=32 is supported"); + + const int N = (int)param.numel(); + const int pdtype = dtype_code(param); + + bool use_ecc = false; + int ecc_bits = 8; + void* ecc_ptr = nullptr; + + if (ecc.has_value() && ecc->defined() && ecc->numel() > 0) { + use_ecc = true; + TORCH_CHECK(ecc->is_contiguous(), "flash_adam: ecc must be contiguous"); + if (ecc->dtype() == torch::kInt8) { + ecc_bits = 8; + } else if (ecc->dtype() == torch::kInt16) { + ecc_bits = 16; + } else { + TORCH_CHECK(false, "flash_adam: ecc must be int8 or int16"); + } + ecc_ptr = ecc->data_ptr(); + } + + // Resolve the current CUDA stream so the kernel is enqueued correctly + cudaStream_t stream = at::cuda::getCurrentCUDAStream(param.device().index()); + + launch_flash_adam( + quantize ? 
mom.data_ptr() : reinterpret_cast(mom.data_ptr()), + quantize ? var.data_ptr() : reinterpret_cast(var.data_ptr()), + quantize ? reinterpret_cast<__half*>(mom_scales.data_ptr()) : nullptr, + quantize ? reinterpret_cast<__half*>(var_scales.data_ptr()) : nullptr, + param.data_ptr(), + grad.data_ptr(), + ecc_ptr, + pdtype, + N, + (float)lr, + (float)beta1, + (float)beta2, + (float)eps, + (float)weight_decay, + (int)step, + quantize, + decoupled, + use_ecc, + ecc_bits, + (int)group_size, + stream + ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashOptim fused Adam CUDA kernel"; + m.def( + "adam_step", + &adam_step, + "Fused Adam/AdamW step (CUDA)", + py::arg("mom"), + py::arg("mom_scales"), + py::arg("var"), + py::arg("var_scales"), + py::arg("param"), + py::arg("grad"), + py::arg("ecc"), + py::arg("lr"), + py::arg("beta1"), + py::arg("beta2"), + py::arg("eps"), + py::arg("weight_decay"), + py::arg("step"), + py::arg("quantize") = true, + py::arg("decoupled") = false, + py::arg("group_size") = 32 + ); +} diff --git a/flashoptim/optimizers.py b/flashoptim/optimizers.py index dbb49c7..7e4f800 100644 --- a/flashoptim/optimizers.py +++ b/flashoptim/optimizers.py @@ -34,6 +34,23 @@ import triton.language.extra.libdevice as libdevice +_cuda_adam_ext = None +_cuda_adam_load_attempted = False + + +def _try_load_cuda_adam_ext(): + """Lazily load the CUDA Adam extension on first use.""" + global _cuda_adam_ext, _cuda_adam_load_attempted + if _cuda_adam_load_attempted: + return + _cuda_adam_load_attempted = True + try: + import flashoptim._cuda_adam as _ext # noqa: F401 + _cuda_adam_ext = _ext + except ImportError: + pass # silently fall back to Triton + + class NumericsError(RuntimeError): """The optimizer detected that the learning rate is too small to meaningfully update weights at their current magnitude and dtype.""" @@ -2221,6 +2238,29 @@ def _fused_adam_step( N = param.numel() if N == 0: return + + _try_load_cuda_adam_ext() + if _cuda_adam_ext is not 
None:
+        # One fused kernel call replaces the Triton path below entirely.
+        # NOTE(review): `errors` is forwarded to the extension raw, before the
+        # _ecc_kernel_params() normalisation applied on the Triton path —
+        # confirm the extension accepts the same ECC tensor layout.
+        _cuda_adam_ext.adam_step(
+            mom,
+            # Scale tensors are only meaningful when states are quantized;
+            # otherwise the (ignored) state tensor is passed as a placeholder.
+            mom_scales_f16 if quantize_optim_states else mom,
+            var,
+            var_scales_f16 if quantize_optim_states else var,
+            param,
+            grad,
+            errors,
+            lr,
+            beta1,
+            beta2,
+            eps,
+            weight_decay,
+            step,
+            quantize_optim_states,
+            decoupled,
+            group_size,
+        )
+        return
+
     use_ecc, signed_max_val, errors, signed_error_t = _ecc_kernel_params(errors, param)
     grid = functools.partial(_make_grid, N)
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..40a36dd
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# setup.py
+#
+# Builds the optional CUDA extension `flashoptim._cuda_adam` when NVCC is
+# available. The pure-Python / Triton path remains fully functional without it.
+#
+# Build:
+#   pip install -e .                          # Python + Triton only (no nvcc needed)
+#   FLASHOPTIM_BUILD_CUDA=1 pip install -e .  # force CUDA ext even if nvcc
+#                                             # detection would skip it
+#   python setup.py build_ext --inplace       # build in-place for dev
+
+import os
+import shutil
+import sys
+
+from setuptools import setup
+
+
+def _nvcc_available() -> bool:
+    """True when the CUDA extension should be built (nvcc on PATH, or forced)."""
+    # The env override exists for builds where nvcc is not on PATH
+    # (containers, cross-compilation) but a toolchain is still usable.
+    if os.environ.get("FLASHOPTIM_BUILD_CUDA", "0") == "1":
+        return True
+    return shutil.which("nvcc") is not None
+
+
+def _get_gencode_flags():
+    """
+    Build -gencode flags. We always include a broad baseline set and then
+    attempt to add the SM of whatever GPU is currently installed so that
+    the native cubin is embedded (faster JIT + guarantees no 'no kernel image'
+    errors at runtime).
+ """ + sm_set = {80, 86, 89, 90, 100} + + try: + import torch + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + p = torch.cuda.get_device_properties(i) + sm = p.major * 10 + p.minor + sm_set.add(sm) + except Exception: + pass + + flags = [] + for sm in sorted(sm_set): + flags.append(f"-gencode=arch=compute_{sm},code=sm_{sm}") + # Also add a PTX target for forward-compatibility with future GPUs + max_sm = max(sm_set) + flags.append(f"-gencode=arch=compute_{max_sm},code=compute_{max_sm}") + return flags + + +ext_modules = [] + +if _nvcc_available(): + try: + from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + cuda_ext = CUDAExtension( + name="flashoptim._cuda_adam", + sources=[ + "flashoptim/csrc/flash_adam_cuda.cu", + "flashoptim/csrc/flash_adam_cuda_ext.cpp", + ], + include_dirs=["flashoptim/csrc"], + extra_compile_args={ + "nvcc": [ + "-O3", + "--ftz=true", + "--prec-div=false", + "-std=c++17", + "--expt-relaxed-constexpr", + "--extended-lambda", + "-lineinfo", + ] + _get_gencode_flags(), + "cxx": ["-O3", "-std=c++17"], + }, + ) + ext_modules.append(cuda_ext) + + setup( + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + ) + + except Exception as exc: + print( + f"[flashoptim] WARNING: Could not configure CUDA extension ({exc}). " + "Falling back to Triton-only install.", + file=sys.stderr, + ) + setup() +else: + setup() diff --git a/test/test_cuda_adam.py b/test/test_cuda_adam.py new file mode 100644 index 0000000..b503222 --- /dev/null +++ b/test/test_cuda_adam.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright 2026 Databricks, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# test/test_cuda_adam.py +# +# Correctness tests for the CUDA Adam kernel. 
+#
+
+
+import math  # NOTE(review): appears unused in this module — confirm and drop
+
+import pytest
+import torch
+
+
+# Capability probes evaluated once at import time: the CUDA tests need both
+# a GPU and the compiled extension.
+CUDA_AVAILABLE = torch.cuda.is_available()
+CUDA_EXT_AVAILABLE = False
+if CUDA_AVAILABLE:
+    try:
+        import flashoptim._cuda_adam  # noqa: F401
+        CUDA_EXT_AVAILABLE = True
+    except ImportError:
+        pass
+
+requires_cuda_ext = pytest.mark.skipif(
+    not (CUDA_AVAILABLE and CUDA_EXT_AVAILABLE),
+    reason="CUDA GPU + compiled flashoptim._cuda_adam extension required",
+)
+requires_cuda = pytest.mark.skipif(
+    not CUDA_AVAILABLE,
+    reason="CUDA GPU required",
+)
+
+
+def _make_state(N, device, quantize, dtype=torch.bfloat16, seed=42):
+    """Return (param, grad, mom, mom_scales, var, var_scales).
+
+    Seeded so repeated calls produce identical fixtures. With quantize=True
+    the states are int8/uint8 plus one fp16 scale per 32-element group;
+    otherwise the states are full-precision in `dtype` and the scale tensors
+    are empty placeholders.
+    """
+    rng = torch.Generator(device=device).manual_seed(seed)
+    param = torch.randn(N, device=device, dtype=dtype, generator=rng)
+    grad = torch.randn(N, device=device, dtype=dtype, generator=rng) * 0.01
+    if quantize:
+        G = (N + 31) // 32  # one scale per group of 32 elements
+        mom = torch.randint(-10, 10, (N,), device=device, dtype=torch.int8)
+        mom_scales = (torch.rand(G, device=device, generator=rng, dtype=torch.float32) * 0.1 + 0.01).half()
+        var = torch.randint(10, 50, (N,), device=device, dtype=torch.uint8)
+        var_scales = (torch.rand(G, device=device, generator=rng, dtype=torch.float32) * 0.01 + 1e-4).half()
+    else:
+        mom = torch.zeros(N, device=device, dtype=dtype)
+        mom_scales = torch.empty(0, device=device, dtype=torch.float16)
+        var = torch.zeros(N, device=device, dtype=dtype)
+        var_scales = torch.empty(0, device=device, dtype=torch.float16)
+    return param, grad, mom, mom_scales, var, var_scales
+
+
+def _fp32_adam_step(param_fp32, grad_fp32, mom_fp32, var_fp32,
+                    lr, beta1, beta2, eps, wd, step, decoupled):
+    """Reference Adam step entirely in fp32 (no quantization)."""
+    if decoupled:
+        # NOTE(review): decay is applied as (1 - wd), i.e. independent of lr,
+        # rather than the common AdamW form (1 - lr*wd) — confirm this matches
+        # the kernel's decoupled-decay convention.
+        param_fp32.mul_(1.0 - wd)
+    else:
+        # Classic (coupled) L2: fold the decay term into the gradient.
+        grad_fp32 = grad_fp32 + param_fp32 * wd
+
+    mom_fp32.mul_(beta1).add_(grad_fp32, alpha=1.0 - beta1)
+    var_fp32.mul_(beta2).addcmul_(grad_fp32, grad_fp32, value=1.0 - beta2)
+
+    # Bias-correction terms for step `step` (1-indexed).
+    bc1 = 1.0 - beta1 ** step
+    bc2 = 1.0 - beta2 ** step
+    m_hat
= mom_fp32 / bc1
+    v_hat = var_fp32 / bc2
+    param_fp32.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
+
+
+def _run_cuda(param, grad, mom, mom_scales, var, var_scales,
+              lr, beta1, beta2, eps, wd, step, quantize, decoupled):
+    """Run one fused step through the compiled CUDA extension (in-place)."""
+    import flashoptim._cuda_adam as ext
+    ext.adam_step(
+        mom, mom_scales, var, var_scales,
+        param, grad, None,
+        lr, beta1, beta2, eps, wd, step,
+        quantize, decoupled, 32,
+    )
+
+
+def _run_triton(param, grad, mom, mom_scales, var, var_scales,
+                lr, beta1, beta2, eps, wd, step, quantize, decoupled):
+    """Run one step through the Triton path by masking the CUDA extension."""
+    import flashoptim.optimizers as opt_mod
+    opt_mod._try_load_cuda_adam_ext()
+    # Temporarily hide the extension so _fused_adam_step takes the Triton path.
+    orig, opt_mod._cuda_adam_ext = opt_mod._cuda_adam_ext, None
+    try:
+        opt_mod._fused_adam_step(
+            mom, mom_scales, var, var_scales,
+            param, grad, None,
+            lr, beta1, beta2, eps, wd, decoupled, step,
+            quantize_optim_states=quantize,
+        )
+    finally:
+        opt_mod._cuda_adam_ext = orig
+
+
+# ---------------------------------------------------------------------------
+# 1. CUDA vs fp32 ground truth – param value
+# ---------------------------------------------------------------------------
+
+@requires_cuda_ext
+@pytest.mark.parametrize("N", [32, 128, 1024, 4096])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("decoupled", [True, False])
+def test_cuda_param_vs_fp32_no_quant(N, dtype, decoupled):
+    """Without quantization, CUDA param update matches fp32 reference closely."""
+    device = "cuda"
+    lr, beta1, beta2, eps, wd, step = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
+
+    param_c, grad_c, mom_c, ms_c, var_c, vs_c = _make_state(N, device, False, dtype)
+    # fp32 shadow copies drive the reference update from identical state.
+    param_fp32 = param_c.float().clone()
+    grad_fp32 = grad_c.float().clone()
+    mom_fp32 = mom_c.float().clone()
+    var_fp32 = var_c.float().clone()
+
+    _run_cuda(param_c, grad_c, mom_c, ms_c, var_c, vs_c,
+              lr, beta1, beta2, eps, wd, step, False, decoupled)
+    _fp32_adam_step(param_fp32, grad_fp32, mom_fp32, var_fp32,
+                    lr, beta1, beta2, eps, wd, step, decoupled)
+
+    # bf16/fp16 have ~3e-3
relative precision; allow a few ulps of rounding
+    atol = 5e-3
+    torch.testing.assert_close(
+        param_c.float(), param_fp32,
+        atol=atol, rtol=1e-2,
+        msg=f"param vs fp32 ref: N={N}, dtype={dtype}, decoupled={decoupled}",
+    )
+
+
+# ---------------------------------------------------------------------------
+# 2. CUDA optimizer states vs fp32 ground truth (no quantization)
+# ---------------------------------------------------------------------------
+
+@requires_cuda_ext
+@pytest.mark.parametrize("N", [128, 1024])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("decoupled", [True, False])
+def test_cuda_states_vs_fp32_no_quant(N, dtype, decoupled):
+    """Without quantization, CUDA mom/var states match fp32 reference."""
+    device = "cuda"
+    lr, beta1, beta2, eps, wd, step = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
+
+    param_c, grad_c, mom_c, ms_c, var_c, vs_c = _make_state(N, device, False, dtype)
+    # fp32 shadow copies drive the reference update from identical state.
+    param_fp32 = param_c.float().clone()
+    grad_fp32 = grad_c.float().clone()
+    mom_fp32 = mom_c.float().clone()
+    var_fp32 = var_c.float().clone()
+
+    _run_cuda(param_c, grad_c, mom_c, ms_c, var_c, vs_c,
+              lr, beta1, beta2, eps, wd, step, False, decoupled)
+    _fp32_adam_step(param_fp32, grad_fp32, mom_fp32, var_fp32,
+                    lr, beta1, beta2, eps, wd, step, decoupled)
+
+    # mom/var stored at param dtype; allow rounding
+    atol = 1e-3
+    torch.testing.assert_close(
+        mom_c.float(), mom_fp32, atol=atol, rtol=1e-2,
+        msg=f"mom vs fp32: N={N}, dtype={dtype}, decoupled={decoupled}",
+    )
+    torch.testing.assert_close(
+        var_c.float(), var_fp32, atol=atol, rtol=1e-2,
+        msg=f"var vs fp32: N={N}, dtype={dtype}, decoupled={decoupled}",
+    )
+
+
+# ---------------------------------------------------------------------------
+# 3.
CUDA vs Triton agreement (regression guard for all flag combos)
+# ---------------------------------------------------------------------------
+
+@requires_cuda_ext
+@pytest.mark.parametrize("N", [32, 128, 1024, 4096, 10001])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("quantize", [True, False])
+@pytest.mark.parametrize("decoupled", [True, False])
+def test_cuda_vs_triton_param(N, dtype, quantize, decoupled):
+    """CUDA and Triton produce the same updated param (within tolerance)."""
+    device = "cuda"
+    lr, beta1, beta2, eps, wd, step = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
+
+    # Clone the fixture so both backends start from identical state.
+    param_t, grad_t, mom_t, ms_t, var_t, vs_t = _make_state(N, device, quantize, dtype)
+    param_c = param_t.clone(); grad_c = grad_t.clone()
+    mom_c = mom_t.clone(); ms_c = ms_t.clone()
+    var_c = var_t.clone(); vs_c = vs_t.clone()
+
+    _run_triton(param_t, grad_t, mom_t, ms_t, var_t, vs_t,
+                lr, beta1, beta2, eps, wd, step, quantize, decoupled)
+    _run_cuda(param_c, grad_c, mom_c, ms_c, var_c, vs_c,
+              lr, beta1, beta2, eps, wd, step, quantize, decoupled)
+
+    # Quantized states introduce extra rounding, so the tolerance is looser.
+    atol = 1e-2 if quantize else 1e-4
+    rtol = 1e-2 if quantize else 1e-3
+    torch.testing.assert_close(
+        param_c.float(), param_t.float(), atol=atol, rtol=rtol,
+        msg=f"param CUDA vs Triton: N={N}, dtype={dtype}, q={quantize}, d={decoupled}",
+    )
+
+
+@requires_cuda_ext
+@pytest.mark.parametrize("N", [128, 1024])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("decoupled", [True, False])
+def test_cuda_vs_triton_quant_states(N, dtype, decoupled):
+    """With quantization: CUDA and Triton int8/uint8 states are within ±1 LSB,
+    and scales are close."""
+    device = "cuda"
+    lr, beta1, beta2, eps, wd, step = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
+
+    # Clone the fixture so both backends start from identical state.
+    param_t, grad_t, mom_t, ms_t, var_t, vs_t = _make_state(N, device, True, dtype)
+    param_c = param_t.clone(); grad_c = grad_t.clone()
+    mom_c = mom_t.clone(); ms_c = ms_t.clone()
+    var_c = var_t.clone(); vs_c = vs_t.clone()
+
+    _run_triton(param_t, grad_t, mom_t, ms_t, var_t, vs_t,
+                lr, beta1, beta2, eps, wd, step, True, decoupled)
+    _run_cuda(param_c, grad_c, mom_c, ms_c, var_c, vs_c,
+              lr, beta1, beta2, eps, wd, step, True, decoupled)
+
+    # Compare quantized codes in int space: at most one quantization bucket
+    # of disagreement per element.
+    diff_mom = (mom_c.int() - mom_t.int()).abs().max().item()
+    diff_var = (var_c.int() - var_t.int()).abs().max().item()
+    assert diff_mom <= 1, f"mom int8 mismatch > 1 LSB: max_diff={diff_mom}, dtype={dtype}, d={decoupled}"
+    assert diff_var <= 1, f"var uint8 mismatch > 1 LSB: max_diff={diff_var}, dtype={dtype}, d={decoupled}"
+
+    torch.testing.assert_close(ms_c, ms_t, atol=1e-3, rtol=1e-2, msg="mom_scales mismatch")
+    torch.testing.assert_close(vs_c, vs_t, atol=1e-3, rtol=1e-2, msg="var_scales mismatch")
+
+
+# ---------------------------------------------------------------------------
+# 4. Tail / boundary sizes
+# ---------------------------------------------------------------------------
+
+@requires_cuda_ext
+@pytest.mark.parametrize("N", [1, 3, 31, 33, 63, 65, 127, 129, 255, 257, 10001])
+@pytest.mark.parametrize("quantize", [True, False])
+def test_cuda_vs_triton_boundary_sizes(N, quantize):
+    """CUDA handles non-aligned tensor sizes correctly (tail elements)."""
+    device = "cuda"
+    dtype = torch.bfloat16
+    lr, beta1, beta2, eps, wd, step = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
+
+    # Sizes straddle the 32-element group boundary to exercise tail handling.
+    param_t, grad_t, mom_t, ms_t, var_t, vs_t = _make_state(N, device, quantize, dtype)
+    param_c = param_t.clone(); grad_c = grad_t.clone()
+    mom_c = mom_t.clone(); ms_c = ms_t.clone()
+    var_c = var_t.clone(); vs_c = vs_t.clone()
+
+    _run_triton(param_t, grad_t, mom_t, ms_t, var_t, vs_t,
+                lr, beta1, beta2, eps, wd, step, quantize, True)
+    _run_cuda(param_c, grad_c, mom_c, ms_c, var_c, vs_c,
+              lr, beta1, beta2, eps, wd, step, quantize, True)
+
+    atol = 1e-2 if quantize else 1e-4
+    torch.testing.assert_close(
+        param_c.float(), param_t.float(), atol=atol, rtol=1e-2,
+        msg=f"boundary N={N}, quantize={quantize}",
+    )
+
+
+# 
---------------------------------------------------------------------------
+# 5. Multi-step numerical stability and drift
+# ---------------------------------------------------------------------------
+
+@requires_cuda_ext
+@pytest.mark.parametrize("quantize", [True, False])
+@pytest.mark.parametrize("decoupled", [True, False])
+def test_cuda_multi_step_vs_triton(quantize, decoupled):
+    """Over 20 steps, CUDA and Triton produce identical params (within ±1 quant LSB per step)."""
+    N, device, dtype = 4096, "cuda", torch.bfloat16
+    lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.01
+
+    param_t, grad_t, mom_t, ms_t, var_t, vs_t = _make_state(N, device, quantize, dtype)
+    param_c = param_t.clone(); mom_c = mom_t.clone(); ms_c = ms_t.clone()
+    var_c = var_t.clone(); vs_c = vs_t.clone()
+
+    for step in range(1, 21):
+        # Fresh deterministic gradient each step, shared by both backends.
+        grad_new = torch.randn(N, device=device, dtype=dtype,
+                               generator=torch.Generator(device).manual_seed(step)) * 0.01
+        _run_triton(param_t, grad_new, mom_t, ms_t, var_t, vs_t,
+                    lr, beta1, beta2, eps, wd, step, quantize, decoupled)
+        _run_cuda(param_c, grad_new, mom_c, ms_c, var_c, vs_c,
+                  lr, beta1, beta2, eps, wd, step, quantize, decoupled)
+
+    assert not param_c.isnan().any(), "NaN in CUDA param after 20 steps"
+    assert not param_c.isinf().any(), "Inf in CUDA param after 20 steps"
+
+    # CUDA and Triton should agree closely; quantization adds ~1 LSB per step
+    atol = 1e-2 if quantize else 1e-4
+    torch.testing.assert_close(
+        param_c.float(), param_t.float(), atol=atol, rtol=1e-2,
+        msg=f"20-step CUDA vs Triton: quantize={quantize}, decoupled={decoupled}",
+    )
+
+
+# ---------------------------------------------------------------------------
+# 6. 
Integration: FlashAdamW dispatches to CUDA ext
+# ---------------------------------------------------------------------------
+
+@requires_cuda_ext
+def test_flash_adamw_uses_cuda_ext(monkeypatch):
+    """FlashAdamW dispatches to the CUDA extension when available."""
+    import flashoptim
+    import flashoptim.optimizers as opt_mod
+
+    opt_mod._try_load_cuda_adam_ext()
+
+    # Spy on the extension entry point to count dispatches.
+    calls = []
+    orig_step = opt_mod._cuda_adam_ext.adam_step
+
+    def _spy(*args, **kwargs):
+        calls.append(1)
+        return orig_step(*args, **kwargs)
+
+    monkeypatch.setattr(opt_mod._cuda_adam_ext, "adam_step", _spy)
+    # Prevent a re-load from replacing the spied-on module object mid-test.
+    monkeypatch.setattr(opt_mod, "_cuda_adam_load_attempted", True)
+
+    model = torch.nn.Linear(64, 64, bias=False).cuda().bfloat16()
+    optimizer = flashoptim.FlashAdamW(model.parameters(), lr=1e-3)
+    x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
+    model(x).sum().backward()
+    optimizer.step()
+
+    assert len(calls) > 0, "CUDA extension was not called during optimizer.step()"
+
+
+# ---------------------------------------------------------------------------
+# 7. Smoke test: extension imports cleanly
+# ---------------------------------------------------------------------------
+
+@requires_cuda
+def test_cuda_ext_importable():
+    """The compiled extension imports and exposes the adam_step entry point."""
+    if not CUDA_EXT_AVAILABLE:
+        pytest.skip("flashoptim._cuda_adam not compiled")
+    import flashoptim._cuda_adam as ext  # noqa: F401
+    assert hasattr(ext, "adam_step")