
[feat] JoyAI-JoyImage-Edit support#13444

Open
Moran232 wants to merge 25 commits into huggingface:main from Moran232:joyimage_edit

Conversation

@Moran232

@Moran232 Moran232 commented Apr 10, 2026

Description

We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.

GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430

Model Overview

JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).

Key Features

  • Advanced Text Rendering Showcase: JoyAI-Image is optimized for challenging text-heavy scenarios, including multi-panel comics, dense multi-line text, multilingual typography, long-form layouts, real-world scene text, and handwritten styles.
  • Multi-view Generation and Spatial Editing Showcase: JoyAI-Image showcases a spatially grounded generation and editing pipeline that supports multi-view generation, geometry-aware transformations, camera control, object rotation, and precise location-specific object editing. Across these settings, it preserves scene content, structure, and visual consistency while following viewpoint-sensitive instructions more accurately.
  • Spatial Editing for Spatial Reasoning Showcase: JoyAI-Image provides high-fidelity spatial editing, serving as a powerful catalyst for enhancing spatial reasoning. Compared with Qwen-Image-Edit and Nano Banana Pro, JoyAI-Image-Edit synthesizes the most diagnostic viewpoints by faithfully executing camera motions. These high-fidelity novel views effectively disambiguate complex spatial relations, providing clearer visual evidence for downstream reasoning.

Image edit examples

(spatial editing showcase images)

@github-actions github-actions Bot added models pipelines size/L PR with diff > 200 LOC labels Apr 10, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


thanks for the PR! I left some initial feedback

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Collaborator


ohh what's going on here? is this some legacy code? can we remove?

Author


We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
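The inheritance relationship described here can be sketched as follows. The class names follow the PR; the bodies are deliberately trivial placeholders (the real models are torch modules), shown only to illustrate the base/edit split:

```python
# Illustrative sketch of the JoyImage / JoyImage-Edit inheritance described
# above. Class names mirror the PR; bodies are placeholders, not real code.

class JoyImageTransformer3DModelSketch:
    """Base text-to-image transformer (to be open-sourced later)."""

    def forward(self, latents):
        # stand-in for the base denoising computation
        return [x + 1 for x in latents]


class JoyImageEditTransformer3DModelSketch(JoyImageTransformer3DModelSketch):
    """Edit variant: reuses the base blocks, adds reference-image tokens."""

    def forward(self, latents, condition_latents=None):
        if condition_latents is not None:
            # token-wise concatenation of the image being edited
            latents = latents + condition_latents
        return super().forward(latents)
```

The edit subclass only overrides what differs (conditioning on the input image) and delegates everything else to the base class, which is the pattern the authors describe.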

Comment on lines +371 to +391
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Collaborator


Suggested change: replace the block above with

attn_output, text_attn_output = self.attn(...)

can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)

also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example; I think it is the same): https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75
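The module/processor split the review asks for can be sketched as below: the attention module owns the weights (qkv projections, norms), while a separate processor object implements the computation and can be swapped out. All names and bodies here are simplified stand-ins modeled on FluxAttention/FluxAttnProcessor, not the real diffusers code:

```python
# Hedged sketch of the diffusers Attention/AttnProcessor pattern.
# The processor does the math; the module only owns state and dispatches.

class JoyImageAttnProcessorSketch:
    def __call__(self, attn, hidden_states, encoder_hidden_states):
        # real code: project q/k/v, apply RoPE, run scaled-dot-product
        # attention on the joint sequence, then split image/text outputs
        joint = hidden_states + encoder_hidden_states
        out = [attn.scale * x for x in joint]
        return out[: len(hidden_states)], out[len(hidden_states) :]


class JoyImageAttentionSketch:
    def __init__(self, scale=1.0, processor=None):
        self.scale = scale  # stands in for the qkv projections and norms
        self.processor = processor or JoyImageAttnProcessorSketch()

    def forward(self, hidden_states, encoder_hidden_states):
        # the module only dispatches; the processor does the computation
        return self.processor(self, hidden_states, encoder_hidden_states)
```

Because the processor is a plain object held by the module, alternative implementations (fused, quantized, backend-specific) can be installed without touching the module's weights, which is the point of the diffusers convention.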

Author


Thanks for the reminder. I'll clean up this messy code.

Author


Fix in d397b68

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +242 to +250
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Collaborator


Suggested change: delete the ModulateX class.

Comment on lines +214 to +225
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Collaborator


Suggested change: delete the ModulateDiT class.

Is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX.

Author


Fix in f557113

head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Collaborator


Suggested change: replace the load_modulation call with

self.img_mod = JoyImageModulate(...)

let's remove the load_modulation function and use the layer directly, better to rename to JoyImageModulate too

Author


Ok, I will refactor modulation and use ModulateWan

tacos8me added a commit to tacos8me/taco-desktop-backend that referenced this pull request Apr 11, 2026
New `model="joyai-edit"` on /v1/image-edit and /v2/image-edit, routed to a
separate FastAPI sidecar on 127.0.0.1:8092 that runs JoyImageEditPipeline
from the Moran232/diffusers fork + transformers 4.57.1. Process isolation
needed because the fork's diffusers core registry patches cannot be
vendored (PR huggingface/diffusers#13444 pending) and transformers 4.57.x
is incompatible with our 5.3.0 stack.

Phase 0 VRAM measurement: 50.3 GB resident, 65.5 GB peak reserved at
1024² / 30 steps (well under the 80 GB gate). Passed.

- `joyai_client.py` (NEW, 167 lines): thin httpx wrapper with per-call
  short-lived AsyncClient, split timeouts (180s edit / 60s mgmt),
  HTTPStatus→JoyAIError mapping. Singleton `joyai` exported.
- `config.py`: `JOYAI_SIDECAR_URL` (default http://127.0.0.1:8092) and
  `LOAD_JOYAI` env flag. Off by default.
- `server.py`: three-tenant swap protocol replaces the two-tenant v1.1.4
  helpers. New `_last_gpu_tenant` tracker + `_evict_other_tenants(new)`
  helper. All three `_ensure_*_ready()` helpers are now `async def` —
  13 call sites updated across _dispatch_job and v1 sync handlers.
  IMAGE_EDIT dispatch arm routes `model=="joyai-edit"` to joyai_client;
  validates len(image_paths)==1 (422 otherwise). Lifespan health-probes
  the sidecar when LOAD_JOYAI=1 (non-blocking — joyai-edit returns 503
  if unreachable).
- `flux_manager.py`: pre-existing bug fix — _edit() hardcoded
  ensure_model("flux2-klein"), silently ignoring the dispatcher's
  `model` kwarg. Now accepts and respects `model`. Guidance_scale
  is now conditional on model != "flux2-klein" (Klein strips CFG,
  Dev uses it).
- `tests/test_joyai_client.py` (NEW, 7 tests) + `tests/test_validation.py`
  (+3 tests): 89 tests passing (was 79).
- Docs: API.md, QUICKSTART.md, README.md, CLAUDE.md, AGENTS.md all
  updated with joyai-edit model entry, three-tenant swap diagram,
  latency table, sidecar location/port, LOAD_JOYAI env var, v1.1.8
  changelog entry.

Out-of-tree (not committed here, installed separately):
  /mnt/nvme-1/servers/joyai-sidecar/     (sidecar venv + sidecar.py + run.sh)
  ~/.config/systemd/user/joyai-sidecar.service

Smoke-tested end-to-end: upload → /v2/image-edit joyai-edit →
SSE stream (phase denoising → encoding → None) → fetch WEBP result
(352 KB, 91 s wall clock for 20 steps at 1024²). Three-tenant swap
evicted LTX and reloaded it cleanly via _evict_other_tenants.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +454 to +459
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)

Collaborator


Suggested change: delete the self.args block.

I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
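The reason self.config suffices here: diffusers' ConfigMixin records __init__ arguments (via its @register_to_config decorator) and exposes them as attributes on self.config. Below is a tiny dependency-free stand-in for that mechanism, to show why the SimpleNamespace is redundant; it is not the real diffusers implementation:

```python
# Minimal sketch of the register_to_config idea: the decorator captures the
# constructor kwargs into self.config, so fields like is_repa are readable
# later as self.config.is_repa. Illustrative only.
from types import SimpleNamespace


def register_to_config_sketch(init):
    def wrapped(self, **kwargs):
        self.config = SimpleNamespace(**kwargs)  # diffusers uses a FrozenDict
        init(self, **kwargs)
    return wrapped


class TransformerSketch:
    @register_to_config_sketch
    def __init__(self, is_repa=False, repa_layer=8):
        pass  # no self.args needed; read self.config.is_repa instead
```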

Author

@Moran232 Moran232 Apr 14, 2026


I deleted this repa logic, see f557113

Collaborator


Was the repa logic removed because it is not used in inference?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment on lines +900 to +901
timesteps: List[int] = None,
sigmas: List[float] = None,
Collaborator


Suggested change:
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,

nit: could we switch to Python 3.9+ style implicit type hints here and elsewhere?
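One runtime detail worth noting about the suggested spelling: built-in generics like list[int] are PEP 585 (Python 3.9+), while the X | None union syntax is PEP 604 (Python 3.10+ when evaluated). With `from __future__ import annotations` (PEP 563), annotations are kept as strings and never evaluated at definition time, so both spellings work on older interpreters too. A sketch with a hypothetical helper name (only the signature style matters):

```python
from __future__ import annotations


def retrieve_timesteps_sketch(
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
) -> list[int]:
    # hypothetical helper; illustrates the modern hint style in a signature
    if timesteps is not None:
        return timesteps
    return [int(s * 1000) for s in (sigmas or [])]
```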

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left an initial design review :).

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026
@Moran232
Author

@yiyixuxu @dg845
Thank you very much for your valuable feedback. I've made some modifications. See my latest commits.

Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; since that logic is quite complex, I copied it over directly from the training code.

If you have any further suggestions, please feel free to share. Thank you so much!

# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
Collaborator


If I remember correctly, the attention sublayer used to use the custom RMSNorm module, which upcasted to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
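For reference, the FP32-upcasting pattern the comment describes usually looks like the following in DiT codebases: the mean-of-squares is computed in float32 for numerical stability and the result is cast back to the input dtype. This is an illustrative module, not the code from this PR:

```python
# Sketch of an RMSNorm that upcasts to FP32 for the RMS computation,
# in contrast to an implementation that stays in the input dtype.
import torch
from torch import nn


class RMSNormFP32(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()  # upcast: the mean-of-squares is taken in FP32
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * self.weight.float()).to(dtype)  # cast back to input dtype
```

In bfloat16, the difference is the precision of the variance accumulation; whether this matters for parity with the original weights is exactly the question the reviewer is raising.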

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread setup.py Outdated
@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 28, 2026
@dg845
Collaborator

dg845 commented Apr 28, 2026

@bot /style

@github-actions
Contributor

github-actions Bot commented Apr 28, 2026

Style bot fixed some files and pushed the changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

# ---------------------------------------------------------------------------


@maybe_allow_in_graph
Collaborator


Suggested change: remove the @maybe_allow_in_graph decorator.

I think we can remove @maybe_allow_in_graph here as the transformer torch.compile tests in TestJoyImageEditTransformerCompile will pass even without it.



remove in f364da3

Comment on lines +471 to +474
hidden_states = hidden_states + _apply_gate(
    self.img_mlp(_modulate(self.img_norm2(hidden_states), img_mod2_shift, img_mod2_scale)),
    img_mod2_gate,
)
Collaborator


I think it would be more clear if we inlined the _modulate and _apply_gate helpers. We generally prefer not to have too many small helper functions.
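What the inlined version looks like, assuming the usual AdaLN definitions (_modulate(h, shift, scale) = h * (1 + scale) + shift and _apply_gate(h, gate) = gate * h; that formula is an assumption about this PR's helpers, shown only to illustrate the inlining):

```python
# Sketch of the inlined form: the two helper calls collapse into plain
# tensor arithmetic directly on the block. Norm and MLP are placeholders.
import torch


def mlp_step_inlined(hidden_states, shift, scale, gate, norm, mlp):
    modulated = norm(hidden_states) * (1 + scale) + shift  # was _modulate(...)
    return hidden_states + gate * mlp(modulated)           # was _apply_gate(...)
```

The arithmetic is identical; only the indirection through tiny helper functions disappears, which is the diffusers style preference the reviewer mentions.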



fix in f364da3.


_skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"]
_no_split_modules = ["JoyImageTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
Collaborator


Suggested change: keep _keep_in_fp32_modules as-is and add:
_repeated_blocks = ["JoyImageTransformerBlock"]

We should define _repeated_blocks so that regional compilation is supported. This also allows the TestJoyImageEditTransformerCompile.test_torch_compile_repeated_blocks test to pass.
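The idea behind _repeated_blocks, sketched without torch: instead of compiling the whole transformer graph, each module whose class is listed gets compiled once per instance, so compile time scales with the block, not the full model. The dispatcher below is a toy stand-in (names and structure hypothetical; the real mechanism lives in diffusers/torch.compile internals):

```python
# Toy illustration of regional compilation driven by a _repeated_blocks
# list: only modules whose class name is listed get "compiled" (here:
# wrapped). Purely illustrative.

def compile_repeated_blocks_sketch(modules, repeated_blocks):
    compiled_count = 0
    for module in modules:
        if module["cls"] in repeated_blocks:
            fwd = module["forward"]
            module["forward"] = lambda x, f=fwd: f(x)  # stand-in for torch.compile(f)
            compiled_count += 1
    return compiled_count
```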



fix in e45e1ad.

return Transformer2DModelOutput(sample=img)


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Collaborator

@dg845 dg845 Apr 28, 2026


I think removing this alias and using one name (for example, JoyImageTransformer3DModel) everywhere would be more clear.



#13444 (comment)

We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.

Due to the relationship between JoyImage and JoyImage-Edit, we need to reserve class designs for the JoyImage models (will be released in the future) to avoid potential confusion.

from .pipeline_output import JoyImageEditPipelineOutput


EXAMPLE_DOC_STRING = """"""
Collaborator


Can we add an example here (for example, based on https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers#running-with-diffusers)? See e.g. WanPipeline for an example:

EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import AutoencoderKLWan, WanPipeline
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler



ok, add an example in aeaa334

Comment on lines +100 to +106
@dataclass
class _LegacyPipelineOutput(BaseOutput):
    """Legacy output dataclass retained for backward compatibility."""

    videos: Union[torch.Tensor, np.ndarray]


Collaborator


Suggested change: remove the _LegacyPipelineOutput class.

I think we can remove _LegacyPipelineOutput as it doesn't appear to be used anywhere.



removed in f364da3.

Collaborator

@dg845 dg845 left a comment


Thanks for iterating! Can you run make fix-copies to make dummy objects for the new transformer and pipeline? This will help the CI pass.

Also, if I try to run the pipeline using the following script:

import torch
from diffusers import JoyImageEditPipeline
from diffusers.utils import load_image

pipeline = JoyImageEditPipeline.from_pretrained(
    "jdopensource/JoyAI-Image-Edit-Diffusers", torch_dtype=torch.bfloat16
)
pipeline.to("cuda")

img_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image = load_image(img_path)

prompt = "Add wings to the astronaut."
prompts = [f"<|im_start|>user\n<image>\n{prompt}<|im_end|>\n"]

image = pipeline(
    image=image,
    prompt=prompt,
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]

image.save("joyai_image_edit_output.png")

it appears the transformer doesn't load correctly:

Traceback (most recent call last):
  File "~/diffusers/scripts/joyimage_edit_test.py", line 6, in <module>
    pipeline = JoyImageEditPipeline.from_pretrained(
  ...
  File "~/diffusers/src/diffusers/models/transformers/transformer_joyimage.py", line 525, in __init__
    raise ValueError(
ValueError: hidden_size (4096) must be divisible by num_attention_heads (24)

Can you look into this?
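The arithmetic behind the error is easy to verify: 4096 is not divisible by 24 (remainder 16), whereas a head count such as 32 would give head_dim = 128. A minimal reproduction of the constructor's check (the function name is hypothetical; the error message mirrors the traceback):

```python
# Reproduces the divisibility check that raises in the traceback above.

def head_dim_or_error(hidden_size: int, num_attention_heads: int) -> int:
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by "
            f"num_attention_heads ({num_attention_heads})"
        )
    return hidden_size // num_attention_heads
```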

@feice-huang

it appears the transformer doesn't load correctly:

Traceback (most recent call last):
  File "~/diffusers/scripts/joyimage_edit_test.py", line 6, in <module>
    pipeline = JoyImageEditPipeline.from_pretrained(
  ...
  File "~/diffusers/src/diffusers/models/transformers/transformer_joyimage.py", line 525, in __init__
    raise ValueError(
ValueError: hidden_size (4096) must be divisible by num_attention_heads (24)

Can you look into this?

@dg845
We removed the unused parameter 'heads_num' from the transformer and modified the default value of 'num_attention_heads' in 3ed6ca9, which caused changes in JoyAI-Image-Edit-Diffusers/transformer/config.json.

We have updated this config.json, and now it would load correctly.

@github-actions github-actions Bot added utils size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 28, 2026
@feice-huang

@dg845 @yiyixuxu
Thank you very much for your feedback! I’ve learned a lot from your suggestions. We’ve implemented a new round of revisions based on your comments. Please feel free to share any further suggestions, thank you so much!



Development

Successfully merging this pull request may close these issues.

Support for JoyAI-Image-Edit

5 participants