Add attention mask input support for flash backend#13479

Open
zhtmike wants to merge 2 commits into huggingface:main from zhtmike:flash-attn-mask

Conversation

Contributor

@zhtmike zhtmike commented Apr 15, 2026

What does this PR do?

This PR adds support for an attention mask input when using the flash attention backend enabled via set_attention_backend("flash"). With this change, QwenImagePipeline can run with the flash backend with or without Ulysses sequence parallelism (SP).

For FlashAttention 2, _wrapped_flash_attn_forward cannot be used directly when a mask is applied. To maintain compatibility with the current interface, we introduce an additional branch for FlashAttention that handles attention masks through the varlen kernels.

# forward pass
- w/o mask: _wrapped_flash_attn_forward()
- w/ mask (new): _pack_qkv() -> _wrapped_flash_attn_varlen_forward() -> unpack()
# backward pass
- w/o mask: stored tensors -> _wrapped_flash_attn_backward()
- w/ mask (new): stored packed tensors -> _wrapped_flash_attn_varlen_backward() -> unpack()
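The masked branch above follows the standard varlen pattern: drop masked tokens, concatenate the survivors from all batch entries into one long sequence, hand the kernel cumulative sequence lengths (cu_seqlens) marking the batch boundaries, and scatter the outputs back afterwards. A minimal pure-Python sketch of that bookkeeping (function names here are illustrative, not the PR's actual helpers):

```python
def cu_seqlens_from_mask(mask):
    """Given per-batch boolean key-padding masks (True = keep), return the
    cumulative sequence lengths that varlen flash attention expects."""
    cu = [0]
    for row in mask:
        cu.append(cu[-1] + sum(row))
    return cu

def pack_with_mask(x, mask):
    """Drop masked positions and concatenate all batch entries into one
    packed sequence; also return the (batch, position) indices needed to
    scatter results back (the 'unpack' step)."""
    packed, indices = [], []
    for b, (row, keep) in enumerate(zip(x, mask)):
        for t, k in enumerate(keep):
            if k:
                packed.append(row[t])
                indices.append((b, t))
    return packed, indices

def unpack(packed, indices, batch, seqlen, fill=0):
    """Scatter packed outputs back to padded (batch, seqlen) layout."""
    out = [[fill] * seqlen for _ in range(batch)]
    for val, (b, t) in zip(packed, indices):
        out[b][t] = val
    return out
```

With mask = [[True, True, False], [True, False, False]], cu_seqlens comes out as [0, 2, 3]: the first sequence occupies packed positions 0-1, the second occupies position 2.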

I haven't tested with ring attention, so that path is left unimplemented.

Fixes # (issue)

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot added models tests size/L PR with diff > 200 LOC labels Apr 15, 2026
Contributor Author

zhtmike commented Apr 15, 2026

Code snippet showing that it works:

import torch
import torch.distributed as dist
import argparse
import os
from diffusers import QwenImagePipeline
from diffusers import ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser(
        description="Test Qwen-Image with Context Parallelism")
    return parser.parse_args()


args = parse_args()

if dist.is_available():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = os.path.expanduser("~/models/Qwen/Qwen-Image")

pipe = QwenImagePipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
pipe.to(device)

pipe.transformer.set_attention_backend("flash")   # <--------- here 
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(
        ulysses_degree=world_size))

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}
prompts = [
    "A coffee shop entrance features a chalkboard sign reading "
    '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
    'displaying "通义千问". Next to it hangs a poster showing a '
    "beautiful Chinese woman, and beneath the poster is written "
    '"π≈3.1415926-53589793-23846264-33832795-02384197". '
    "Ultra HD, 4K, cinematic composition",
    "A cute cat with long hair sitting on a sofa, Ultra HD, 4K, cinematic composition."
]

inputs = {
    "prompt": [p + positive_magic["en"] for p in prompts],
    "generator": torch.Generator(device="cpu").manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "height": 1024,
    "width": 1024,
}

with torch.inference_mode():
    output = pipe(**inputs)
    for i, output_image in enumerate(output.images):
        if world_size > 1:
            save_path = f"output_image_ulysses{world_size}_{i}.png"
        else:
            save_path = f"output_image_{i}.png"
        if rank == 0:
            output_image.save(save_path)
            print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()
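For reference, the snippet can be launched across multiple GPUs with torchrun, which sets up the process group that the dist.init_process_group call picks up (the script filename below is assumed):

```shell
# 2-GPU Ulysses run (ulysses_degree = world size = 2); hypothetical filename.
torchrun --nproc_per_node=2 qwen_flash_mask_test.py

# Single-process run: exercises the flash backend with mask, no parallelism.
python qwen_flash_mask_test.py
```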

Produces the following images: output_image_0, output_image_1 (attached).

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 15, 2026

Labels

models, tests, size/L (PR with diff > 200 LOC)
