diff --git a/flashoptim/optimizers.py b/flashoptim/optimizers.py index 3c8d281..6712da5 100644 --- a/flashoptim/optimizers.py +++ b/flashoptim/optimizers.py @@ -636,28 +636,42 @@ def _wrap_state_as_dtensor(state: dict[str, Any], param: torch.Tensor) -> None: mesh = param.device_mesh placements = param.placements + param_local_shape = param.to_local().shape - # Reject uneven shards - DTensor.from_local infers global shape as - # local_size * world_size, which is only correct for even splits. - for mesh_dim, placement in enumerate(placements): - if hasattr(placement, "dim"): - shard_dim = placement.dim - mesh_size = mesh.size(mesh_dim) - if param.shape[shard_dim] % mesh_size != 0: - raise ValueError( - f"DCP checkpointing requires evenly-sharded parameters, " - f"but parameter with shape {param.shape} is unevenly " - f"sharded on dim {shard_dim} across {mesh_size} ranks. " - f"Pad or reshape the parameter so that shape[{shard_dim}] " - f"is divisible by {mesh_size}." - ) + uneven = any( + param.shape[p.dim] % mesh.size(i) != 0 + for i, p in enumerate(placements) + if hasattr(p, "dim") + ) for key, val in state.items(): - if ( + if not ( isinstance(val, torch.Tensor) and not isinstance(val, DTensor) and val.dim() > 0 ): + continue + if val.shape == param_local_shape: + # Pass explicit shape/stride so wrapping is correct on uneven shards. + state[key] = DTensor.from_local( + val, + mesh, + placements, + shape=param.shape, + stride=param.stride(), + run_check=False, + ) + elif uneven: + # e.g. quantized state with a packed shape: default inference + # would scramble DCP on uneven shards, and we have no global + # shape for this layout. Fail loudly rather than silently. + raise ValueError( + f"Cannot safely wrap state tensor {key!r} of shape " + f"{tuple(val.shape)} as a DTensor for unevenly-sharded " + f"param of shape {tuple(param.shape)}: its layout does " + f"not match the param, so the global shape is unknown." 
+ ) + else: state[key] = DTensor.from_local(val, mesh, placements) def _recompute_stats_for_param(self, p: torch.Tensor) -> None: @@ -988,7 +1002,7 @@ def state_dict(self): for param_number in group["params"]: assert isinstance(param_number, int) if param_number not in opt_state: - continue # frozen param, no optimizer state + continue # frozen or pre-first-step param, no optimizer state opt_state[param_number] = self._state_dict_for_param( param_number, opt_state=opt_state, @@ -1083,6 +1097,14 @@ def _ensure_state_initialized( # FSDP2 support: state tensors must be created from local tensors, not DTensors. # This ensures each rank has state for its local parameter shard. p_local = self._get_local_tensor(p) + if not p_local.is_cuda: + raise ValueError( + "FlashOptim requires parameter shards to be on CUDA at optimizer " + "step time. Detected a CPU shard while initializing optimizer " + "state, which usually means FSDP2 or ZeRO CPU offload is enabled. " + "Disable CPU parameter offload to use FlashOptim, for example " + "set fsdp_offload_params: false." + ) quantize = hparams.get("quantize", self._quantize) for key_quant, spec in self.quantized_state_spec.items(): diff --git a/test/test_fsdp2.py b/test/test_fsdp2.py index 8aecd84..8bb763b 100644 --- a/test/test_fsdp2.py +++ b/test/test_fsdp2.py @@ -1132,3 +1132,132 @@ def test_fsdp2_dcp_training_continuation( str(tmp_path / "ckpt"), seed, ) + + +# ============================================================================ +# FSDP2 uneven-shard state_dict test +# ============================================================================ + + +def _run_fsdp2_uneven_shard_state_dict( + rank: int, + world_size: int, + ckpt_dir: str, + seed: int, +) -> None: + """Unevenly sharded params: global shape not divisible by world_size. 
+ + Verifies that optimizer state tensors are wrapped as DTensors whose global + shape matches the param's global shape — not ``local_size * world_size``, + which ``DTensor.from_local``'s default inference would produce — and that + DCP save + load round-trips exactly. + """ + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.state_dict import ( + get_optimizer_state_dict, + set_optimizer_state_dict, + ) + from torch.distributed.tensor import DTensor + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + # Shapes chosen so dim 0 (the FSDP2 shard dim) is not divisible by world_size=2: + # Linear(10, 7) -> weight [7, 10], bias [7] + # Linear(7, 5) -> weight [5, 7], bias [5] + d_in, d_out, hidden_dim = 10, 5, 7 + dtype = torch.bfloat16 + lr = 0.001 + + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + model = _create_simple_model(d_in, d_out, hidden_dim=hidden_dim).to( + device=device, dtype=dtype + ) + dist.barrier() + fully_shard(model) + + # Capture global param shapes for later correctness checks. + param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()} + + # Sanity: at least one param must be unevenly sharded on its sharded dim. 
+ uneven_seen = any( + p.shape[placement.dim] % p.device_mesh.size(mesh_dim) != 0 + for p in model.parameters() + if isinstance(p, DTensor) + for mesh_dim, placement in enumerate(p.placements) + if hasattr(placement, "dim") + ) + assert uneven_seen, "Test setup failed: no uneven shards produced" + + opt = ADAMW_CONFIG.factory( + model.parameters(), + lr=lr, + compress_state_dict=False, + master_weight_bits=None, + ) + loss_fn = nn.MSELoss() + + g = torch.Generator(device=device).manual_seed(seed + rank) + for _ in range(3): + x = torch.randn(4, d_in, device=device, dtype=dtype, generator=g) + y = torch.randn(4, d_out, device=device, dtype=dtype, generator=g) + loss_fn(model(x), y).backward() + opt.step() + opt.zero_grad(set_to_none=True) + + # --- Verify wrapped DTensors carry the correct global shape --- + saved_osd = get_optimizer_state_dict(model, opt) + for fqn, param_state in saved_osd["state"].items(): + expected_shape = param_shapes[fqn] + for key, val in param_state.items(): + if isinstance(val, DTensor) and val.dim() > 0: + assert tuple(val.shape) == expected_shape, ( + f"[Rank {rank}] state[{fqn}].{key} wrapped with global shape " + f"{tuple(val.shape)}, expected {expected_shape}" + ) + + # --- DCP save + load roundtrip --- + dcp.save({"optimizer": saved_osd}, checkpoint_id=ckpt_dir) + dist.barrier() + + loaded_osd = get_optimizer_state_dict(model, opt) + dcp.load({"optimizer": loaded_osd}, checkpoint_id=ckpt_dir) + + for fqn in saved_osd["state"]: + for key in saved_osd["state"][fqn]: + v_saved = saved_osd["state"][fqn][key] + v_loaded = loaded_osd["state"][fqn][key] + if isinstance(v_saved, torch.Tensor) and v_saved.dim() > 0: + vs = v_saved.to_local() if hasattr(v_saved, "to_local") else v_saved + vl = v_loaded.to_local() if hasattr(v_loaded, "to_local") else v_loaded + assert torch.equal(vs, vl), ( + f"[Rank {rank}] state[{fqn}].{key} changed after " + f"DCP roundtrip: max diff = " + f"{(vs.float() - vl.float()).abs().max().item()}" + ) + + 
set_optimizer_state_dict(model, opt, loaded_osd) + + x = torch.randn(4, d_in, device=device, dtype=dtype) + y = torch.randn(4, d_out, device=device, dtype=dtype) + loss_fn(model(x), y).backward() + opt.step() + opt.zero_grad(set_to_none=True) + + del model, opt + torch.cuda.empty_cache() + dist.barrier() + + +@pytest.mark.parametrize("seed", [0], ids=lambda s: f"seed{s}") +def test_fsdp2_uneven_shard_state_dict( + seed: int, + fsdp2_runner, + tmp_path, +) -> None: + fsdp2_runner( + _run_fsdp2_uneven_shard_state_dict, + str(tmp_path / "ckpt"), + seed, + ) diff --git a/test/test_serialization.py b/test/test_serialization.py index 4fc42ac..29a869a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -218,6 +218,31 @@ def test_state_dict_save_load( _CKPT_SEEDS = [0, 1] +def test_state_dict_before_first_step_matches_torch_empty_state(): + """Accelerate inspects optimizer state before the first step during prepare().""" + params = [ + torch.nn.Parameter(torch.randn(8, device="cuda", dtype=torch.bfloat16)), + torch.nn.Parameter(torch.randn(4, device="cuda", dtype=torch.bfloat16)), + ] + + opt = ADAMW_CONFIG.factory(params, lr=1e-3) + state_dict = opt.state_dict() + + assert state_dict["state"] == {} + assert len(state_dict["param_groups"]) == 1 + assert state_dict["param_groups"][0]["params"] == [0, 1] + + +def test_cpu_offloaded_param_raises_helpful_error(): + """CPU-offloaded shards should fail early with a targeted message.""" + param = torch.nn.Parameter(torch.randn(16, device="cpu", dtype=torch.bfloat16)) + param.grad = torch.randn_like(param) + + opt = ADAMW_CONFIG.factory([param], lr=1e-3) + with pytest.raises(ValueError, match="requires parameter shards to be on CUDA"): + opt.step() + + @pytest.mark.parametrize("seed", _CKPT_SEEDS, ids=seed_id) @pytest.mark.parametrize("ckpt_config", _CKPT_CONFIGS, ids=ckpt_id) @pytest.mark.parametrize(