Skip to content
54 changes: 38 additions & 16 deletions flashoptim/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,28 +636,42 @@ def _wrap_state_as_dtensor(state: dict[str, Any], param: torch.Tensor) -> None:

mesh = param.device_mesh
placements = param.placements
param_local_shape = param.to_local().shape

# Reject uneven shards - DTensor.from_local infers global shape as
# local_size * world_size, which is only correct for even splits.
for mesh_dim, placement in enumerate(placements):
if hasattr(placement, "dim"):
shard_dim = placement.dim
mesh_size = mesh.size(mesh_dim)
if param.shape[shard_dim] % mesh_size != 0:
raise ValueError(
f"DCP checkpointing requires evenly-sharded parameters, "
f"but parameter with shape {param.shape} is unevenly "
f"sharded on dim {shard_dim} across {mesh_size} ranks. "
f"Pad or reshape the parameter so that shape[{shard_dim}] "
f"is divisible by {mesh_size}."
)
uneven = any(
param.shape[p.dim] % mesh.size(i) != 0
for i, p in enumerate(placements)
if hasattr(p, "dim")
)

for key, val in state.items():
if (
if not (
isinstance(val, torch.Tensor)
and not isinstance(val, DTensor)
and val.dim() > 0
):
continue
if val.shape == param_local_shape:
# Pass explicit shape/stride so wrapping is correct on uneven shards.
state[key] = DTensor.from_local(
val,
mesh,
placements,
shape=param.shape,
stride=param.stride(),
run_check=False,
)
elif uneven:
# e.g. quantized state with a packed shape: default inference
# would scramble DCP on uneven shards, and we have no global
# shape for this layout. Fail loudly rather than silently.
raise ValueError(
f"Cannot safely wrap state tensor {key!r} of shape "
f"{tuple(val.shape)} as a DTensor for unevenly-sharded "
f"param of shape {tuple(param.shape)}: its layout does "
f"not match the param, so the global shape is unknown."
)
else:
state[key] = DTensor.from_local(val, mesh, placements)

def _recompute_stats_for_param(self, p: torch.Tensor) -> None:
Expand Down Expand Up @@ -988,7 +1002,7 @@ def state_dict(self):
for param_number in group["params"]:
assert isinstance(param_number, int)
if param_number not in opt_state:
continue # frozen param, no optimizer state
continue # frozen or pre-first-step param, no optimizer state
opt_state[param_number] = self._state_dict_for_param(
param_number,
opt_state=opt_state,
Expand Down Expand Up @@ -1083,6 +1097,14 @@ def _ensure_state_initialized(
# FSDP2 support: state tensors must be created from local tensors, not DTensors.
# This ensures each rank has state for its local parameter shard.
p_local = self._get_local_tensor(p)
if not p_local.is_cuda:
raise ValueError(
"FlashOptim requires parameter shards to be on CUDA at optimizer "
"step time. Detected a CPU shard while initializing optimizer "
"state, which usually means FSDP2 or ZeRO CPU offload is enabled. "
"Disable CPU parameter offload to use FlashOptim, for example "
"set fsdp_offload_params: false."
)

quantize = hparams.get("quantize", self._quantize)
for key_quant, spec in self.quantized_state_spec.items():
Expand Down
129 changes: 129 additions & 0 deletions test/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,3 +1132,132 @@ def test_fsdp2_dcp_training_continuation(
str(tmp_path / "ckpt"),
seed,
)


# ============================================================================
# FSDP2 uneven-shard state_dict test
# ============================================================================


def _run_fsdp2_uneven_shard_state_dict(
    rank: int,
    world_size: int,
    ckpt_dir: str,
    seed: int,
) -> None:
    """Exercise optimizer state_dict handling on unevenly sharded params.

    Model dimensions are picked so that no parameter dimension divides by the
    world size. The test asserts that optimizer state tensors are wrapped as
    DTensors carrying the param's true global shape — not the
    ``local_size * world_size`` shape that ``DTensor.from_local``'s default
    inference would produce — and that a DCP save + load roundtrip preserves
    every state tensor exactly.
    """
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )
    from torch.distributed.tensor import DTensor

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Linear(10, 7) -> weight [7, 10], bias [7]; Linear(7, 5) -> weight [5, 7],
    # bias [5]: with world_size=2, every sharded dim splits unevenly.
    d_in, d_out, hidden_dim = 10, 5, 7
    dtype = torch.bfloat16
    lr = 0.001

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    model = _create_simple_model(d_in, d_out, hidden_dim=hidden_dim).to(
        device=device, dtype=dtype
    )
    dist.barrier()
    fully_shard(model)

    # Record each param's global shape before the optimizer touches anything.
    param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}

    # Guard against a silently-even setup: require at least one shard whose
    # sharded dim does not split evenly across the mesh.
    uneven_seen = False
    for p in model.parameters():
        if not isinstance(p, DTensor):
            continue
        for mesh_dim, placement in enumerate(p.placements):
            if not hasattr(placement, "dim"):
                continue
            if p.shape[placement.dim] % p.device_mesh.size(mesh_dim) != 0:
                uneven_seen = True
    assert uneven_seen, "Test setup failed: no uneven shards produced"

    optimizer = ADAMW_CONFIG.factory(
        model.parameters(),
        lr=lr,
        compress_state_dict=False,
        master_weight_bits=None,
    )
    criterion = nn.MSELoss()

    data_gen = torch.Generator(device=device).manual_seed(seed + rank)
    for _step in range(3):
        batch_x = torch.randn(4, d_in, device=device, dtype=dtype, generator=data_gen)
        batch_y = torch.randn(4, d_out, device=device, dtype=dtype, generator=data_gen)
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # --- Verify wrapped DTensors carry the correct global shape ---
    saved_osd = get_optimizer_state_dict(model, optimizer)
    for fqn, param_state in saved_osd["state"].items():
        expected_shape = param_shapes[fqn]
        for key, val in param_state.items():
            if not (isinstance(val, DTensor) and val.dim() > 0):
                continue
            assert tuple(val.shape) == expected_shape, (
                f"[Rank {rank}] state[{fqn}].{key} wrapped with global shape "
                f"{tuple(val.shape)}, expected {expected_shape}"
            )

    # --- DCP save + load roundtrip ---
    dcp.save({"optimizer": saved_osd}, checkpoint_id=ckpt_dir)
    dist.barrier()

    loaded_osd = get_optimizer_state_dict(model, optimizer)
    dcp.load({"optimizer": loaded_osd}, checkpoint_id=ckpt_dir)

    for fqn, param_state in saved_osd["state"].items():
        for key, v_saved in param_state.items():
            if not (isinstance(v_saved, torch.Tensor) and v_saved.dim() > 0):
                continue
            v_loaded = loaded_osd["state"][fqn][key]
            # Compare local shards; plain tensors have no to_local().
            vs = v_saved.to_local() if hasattr(v_saved, "to_local") else v_saved
            vl = v_loaded.to_local() if hasattr(v_loaded, "to_local") else v_loaded
            assert torch.equal(vs, vl), (
                f"[Rank {rank}] state[{fqn}].{key} changed after "
                f"DCP roundtrip: max diff = "
                f"{(vs.float() - vl.float()).abs().max().item()}"
            )

    set_optimizer_state_dict(model, optimizer, loaded_osd)

    # One more step to confirm the optimizer is usable after restore.
    batch_x = torch.randn(4, d_in, device=device, dtype=dtype)
    batch_y = torch.randn(4, d_out, device=device, dtype=dtype)
    criterion(model(batch_x), batch_y).backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    del model, optimizer
    torch.cuda.empty_cache()
    dist.barrier()


@pytest.mark.parametrize("seed", [0], ids=lambda s: f"seed{s}")
def test_fsdp2_uneven_shard_state_dict(
    seed: int,
    fsdp2_runner,
    tmp_path,
) -> None:
    """Launch the uneven-shard state_dict check under the FSDP2 runner."""
    ckpt_dir = str(tmp_path / "ckpt")
    fsdp2_runner(_run_fsdp2_uneven_shard_state_dict, ckpt_dir, seed)
25 changes: 25 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,31 @@ def test_state_dict_save_load(
_CKPT_SEEDS = [0, 1]


def test_state_dict_before_first_step_matches_torch_empty_state():
    """A freshly built optimizer must expose torch-style empty state.

    HF Accelerate calls ``state_dict()`` during ``prepare()``, before any
    ``step()`` has run, so the pre-step shape of the dict matters.
    """
    params = [
        torch.nn.Parameter(torch.randn(n, device="cuda", dtype=torch.bfloat16))
        for n in (8, 4)
    ]

    sd = ADAMW_CONFIG.factory(params, lr=1e-3).state_dict()

    # No per-param state yet; a single group whose params are index ids.
    assert sd["state"] == {}
    groups = sd["param_groups"]
    assert len(groups) == 1
    assert groups[0]["params"] == [0, 1]


def test_cpu_offloaded_param_raises_helpful_error():
    """Stepping on a CPU-resident shard must raise the targeted ValueError."""
    cpu_param = torch.nn.Parameter(
        torch.randn(16, device="cpu", dtype=torch.bfloat16)
    )
    cpu_param.grad = torch.randn_like(cpu_param)
    optimizer = ADAMW_CONFIG.factory([cpu_param], lr=1e-3)

    # The match regex must stay in sync with the optimizer's CPU-shard error.
    with pytest.raises(ValueError, match="requires parameter shards to be on CUDA"):
        optimizer.step()


@pytest.mark.parametrize("seed", _CKPT_SEEDS, ids=seed_id)
@pytest.mark.parametrize("ckpt_config", _CKPT_CONFIGS, ids=ckpt_id)
@pytest.mark.parametrize(
Expand Down