Conversation
Code Review
This pull request significantly enhances sequence parallelism support by implementing ZigZag Ring Attention for long-sequence training and Ulysses-style sequence parallelism for Qwen3.5 linear attention. It also introduces multimodal deepstack patching for Qwen3-VL and refactors the SequenceParallel strategy to better handle complex device meshes and packed/varlen inputs. Feedback focuses on improving code maintainability and robustness, specifically by grouping attributes in the SequenceParallel constructor, removing redundant logic and unused imports, replacing deprecated inspection methods, and centralizing duplicated loss-gathering logic.
```python
self.seq_world_size = None
self.sp_world_size = None
self.rp_world_size = None
self.dp_world_size = None
self.world_size = None
self.attn_implementation = None
self.model_dtype = None
self.tokenizer = None
self.device_mesh = None
self._sp_group = None
self._rp_group = None
self._data_rank_group = None
self._sp_rank = 0
self._rp_rank = 0
self.num_heads = None
self.causal_mask_func = None
self.extra_kwargs = {}
```
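The review asks for these constructor attributes to be grouped. One way to do that is with small dataclasses; the sketch below is hypothetical (the `MeshState`/`GroupState` names and the split are assumptions, not the project's actual refactor), mirroring the fields above:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class MeshState:
    """Hypothetical grouping of the world-size / device-mesh fields."""
    seq_world_size: Optional[int] = None
    sp_world_size: Optional[int] = None
    rp_world_size: Optional[int] = None
    dp_world_size: Optional[int] = None
    world_size: Optional[int] = None
    device_mesh: Any = None


@dataclass
class GroupState:
    """Hypothetical grouping of process-group handles and local ranks."""
    sp_group: Any = None
    rp_group: Any = None
    data_rank_group: Any = None
    sp_rank: int = 0
    rp_rank: int = 0


class SequenceParallel:
    def __init__(self):
        # Two small state objects replace a long run of loose attributes.
        self.mesh = MeshState()
        self.groups = GroupState()
        self.attn_implementation = None
        self.model_dtype = None
        self.tokenizer = None
        self.num_heads = None
        self.causal_mask_func = None
        self.extra_kwargs: dict = {}
```

Grouping this way keeps related state together and makes it cheap to pass the mesh or group bundle to helpers without threading a dozen attributes through call sites.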
```python
if query.shape[2] != total_tokens:
    raise ValueError('Packed/varlen flash_attention_2 expects query sequence length to match '
                     f'cu_seqlens total tokens, got query_seq_len={query.shape[2]} '
                     f'and cu_seqlens_total={total_tokens}.')
```
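For context on what this check compares: in packed/varlen attention, `cu_seqlens` holds cumulative sequence boundaries, so its last entry is the total token count that the flattened query's sequence axis must match. A standalone illustration (toy values, not the project's code):

```python
# cu_seqlens for three packed sequences of lengths 3, 4, and 5.
cu_seqlens = [0, 3, 7, 12]

# The total token count across all packed sequences is the last boundary.
total_tokens = cu_seqlens[-1]

# A packed query tensor must carry exactly total_tokens positions
# along its sequence axis; anything else means the packing metadata
# and the tensor have diverged.
query_seq_len = 12
if query_seq_len != total_tokens:
    raise ValueError(f'query_seq_len={query_seq_len} does not match '
                     f'cu_seqlens_total={total_tokens}')
print(total_tokens)  # 12
```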
```python
if self.rp_world_size > 1:
    attn_impl = getattr(model.config, '_attn_implementation', None)
    if attn_impl != 'flash_attention_2':
        raise NotImplementedError('Derived ring attention only supports flash_attention_2 backend.')
```
```python
import inspect
import os
from functools import cache


@cache
def _get_default_args(func):
    spec = inspect.getfullargspec(func)
    defaults = spec.defaults if spec.defaults is not None else ()
    padded_defaults = (None, ) * (len(spec.args) - len(defaults)) + defaults
    args = dict(zip(spec.args, padded_defaults))
    if 'softcap' in args:
        args['softcap'] = 0.0
    return args
```
```python
if self.sp_strategy is not None:
    loss_inputs, loss_outputs = self.sp_strategy.gather_loss_tensors(inputs, outputs)
```
```python
# Local labels still count only the shard-local tokens. Normalize the loss
# contribution here so metric-side averaging matches the non-SP path.
if ulysses_size > 1:
    loss = loss / float(ulysses_size)
```
Why is the division placed here? Put differently, does the loss used for the model's backward pass need to be divided by ulysses_size?
When `loss_instance` uses `reduction='sum'`, the loss here is the full-sequence loss replicated on every ulysses rank, but the `num_tokens` counted here is still each rank's local shard token count. The two denominators are inconsistent, so the loss is divided once by ulysses_size; this division only fixes the metric-reporting side. The loss used for backpropagation is not divided by ulysses_size: `GatherLoss.apply` keeps only the local gradient.
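The mismatch described in that reply can be checked with plain numbers: with `reduction='sum'`, every ulysses rank holds an identical copy of the full-sequence loss, while each rank's token count covers only its shard. A standalone arithmetic sketch (illustrative values, no distributed setup):

```python
ulysses_size = 4
total_tokens = 1024
full_seq_loss_sum = 2048.0   # replicated identically on every rank

# Each rank counts only the tokens in its local shard.
local_num_tokens = total_tokens // ulysses_size  # 256

# Naive metric aggregation sums per-rank (loss, tokens) pairs, so the
# replicated loss is counted ulysses_size times while tokens are not.
naive_avg = (full_seq_loss_sum * ulysses_size) / (local_num_tokens * ulysses_size)

# Dividing each rank's reported loss by ulysses_size restores the match.
fixed_avg = (full_seq_loss_sum / ulysses_size * ulysses_size) / (
    local_num_tokens * ulysses_size)

true_avg = full_seq_loss_sum / total_tokens
print(naive_avg, fixed_avg, true_avg)  # naive is ulysses_size times too large
```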
```python
from twinkle.utils.grad_clip import normalize_and_clip_grad_norm


def _get_raw_dp_fsdp_world_size(device_mesh: Optional[DeviceMesh]) -> int:
    ...
```
Isn't this the same as the device mesh's dp_world_size? Can that be reused?
Not the same. This computes `dp_world_size * fsdp_world_size`, whereas the device mesh's `dp_world_size` is:

```python
@property
def dp_world_size(self) -> int:
    return self._get_world_size_for_dim('dp')
```
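The distinction in that reply can be made concrete with a toy mesh-shape dict. This is a standalone sketch, not the project's `DeviceMesh` API; the dict layout and helper names are assumptions for illustration:

```python
# Hypothetical 16-GPU layout: 2 data-parallel groups, each FSDP-sharded
# over 4 devices, each replica sequence-parallel over 2 devices.
mesh_shape = {'dp': 2, 'fsdp': 4, 'sp': 2}


def dp_world_size(shape):
    """Mirrors the dp-only property: size of the 'dp' dim alone."""
    return shape.get('dp', 1)


def dp_fsdp_world_size(shape):
    """The helper under discussion: dp and fsdp dims combined, i.e.
    the number of ranks that together hold one copy of the data."""
    return shape.get('dp', 1) * shape.get('fsdp', 1)


print(dp_world_size(mesh_shape))       # 2
print(dp_fsdp_world_size(mesh_shape))  # 8
```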
`dp_world_size` and `data_rank` exist specifically to identify the data group. Even if a combined world size is genuinely needed, the implementation should live in DeviceMesh for reuse rather than in transformers.py; coupling the model code with data-group logic creates maintenance problems.
```python
result = loss_instance(inputs, outputs, **kwargs)
loss_inputs = inputs
loss_outputs = outputs
if self.sp_strategy is not None:
    ...
```
Could this part use the InputProcessor? Since the splitting is done by the InputProcessor, shouldn't the gathering be placed there as well?
That doesn't seem appropriate: by this point we are already in the loss-computation stage, and the InputProcessor's responsibility should be processing inputs.
The name InputProcessor may just be poorly chosen; the component exists to do task-specific data processing. If this logic stays in the model code and another subclass is added, the implementation has to be rewritten all over again.
- Refactor linear attention sequence parallel import error message into a constant
- Fix token counting in TransformersModel by using raw DP/FSDP world size instead of data_world_size
- Enhance Framework.gather_object to check distributed initialization before accessing world size
- Add test utility for creating padded labels in sequence parallel tests
- Add `num_tokens` field to `ModelOutput` TypedDict for explicit token denominator
- Update `LossOutput` to use `OutputType` for `num_tokens` instead of `int`
- Refactor `LossMetric` to prefer `num_tokens` from outputs, with fallback to labels
- Remove `_get_raw_dp_fsdp_world_size` helper and use `_device_mesh._get_dp_fsdp_world_size`
- Use `InputProcessor.postprocess_tensor_sp` for loss tensor gathering in TransformersModel
- Simplify sequence-parallel loss normalization by relying on output `num_tokens`
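The `LossMetric` change above ("prefer `num_tokens` from outputs, with fallback to labels") can be sketched as a small helper. The function name is hypothetical, not the project's API; the `-100` ignore index follows the usual transformers label-padding convention:

```python
IGNORE_INDEX = -100  # conventional label-padding value in transformers


def resolve_num_tokens(outputs: dict, labels: list) -> int:
    """Prefer an explicit num_tokens carried in the model outputs;
    otherwise fall back to counting non-ignored label positions."""
    num_tokens = outputs.get('num_tokens')
    if num_tokens is not None:
        return int(num_tokens)
    return sum(1 for t in labels if t != IGNORE_INDEX)


print(resolve_num_tokens({'num_tokens': 7}, []))      # 7
print(resolve_num_tokens({}, [-100, 5, 9, -100, 3]))  # 3
```

Carrying the denominator in the outputs is what lets the sequence-parallel path report a consistent token count even when the local labels cover only a shard.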
Force-pushed 0e8600d to f13bc37 (Compare)




PR type
PR information
This PR adds context parallel and Qwen3.5 Gated DeltaNet sequence parallel support to the transformers stack, and refactors sequence parallel into a package-based implementation.
Main changes:
- Refactor `sequence_parallel.py` into a `sequence_parallel/` package and add shared utilities.
- Add `linear_attention_sp.py`; ring attention is not supported for this path yet.
- Add tests covering `sp_fsdp_dense` and `tests/moe/test_expert_parallel_qwen3_fsdp_sp.py`.

Experiment results