Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| ckpt_id = "black-forest-labs/FLUX.1-dev" | ||
|
|
||
| # --- Text encoding (CPU) --- | ||
| prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side" |
There was a problem hiding this comment.
nit: again, for clarity I would avoid the "Trillium" word if we test on v5.
There was a problem hiding this comment.
This is probably fine. It's quite separate.
| xs.mark_sharding(param, mesh, tuple(spec)) | ||
|
|
||
| flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) | ||
| FlashAttention.DEFAULT_BLOCK_SIZES = { |
There was a problem hiding this comment.
this looks like black magic, consider adding a comment explaining where these come from
There was a problem hiding this comment.
Cc: @entrpn for those as it's copied from flux_inference.py.
There was a problem hiding this comment.
If I remember correctly, these block sizes have been optimized for Trillium through some tests we ran internally. They can be kept as is as long the v5e's vmem can handle it. These could be optimized in the future specifically for v5e.
|
|
||
| def _vae_decode(latents, vae, height, width, device): | ||
| """Move VAE to XLA, decode latents, move VAE back to CPU.""" | ||
| vae.to(device) |
There was a problem hiding this comment.
I do not know much about this, but isn't moving VAE back and forth between xla device and cpu quite expensive in time? Woudn't it be better just to keep it in XLA?
There was a problem hiding this comment.
It would barely fit, otherwise. Plus we have to free some stuff anyway to do the actual computation. Once it's compiled it doesn't take much of a hit barring some displacement overhead which is likely justifiable given the cheap pricing of v5es. Does it make sense?
| 2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec. | ||
| ``` | ||
|
|
||
| The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s). |
There was a problem hiding this comment.
Perhaps you can include a dummy inference in the compilation part, so that VAE is compiled and timings look more regular.
There was a problem hiding this comment.
I didn't get this part. Elaborate more? The block under "logger.info("starting compilation run...")" has the VAE compilation included.
|
|
||
| def _vae_decode(latents, vae, height, width, device): | ||
| """Move VAE to XLA, decode latents, move VAE back to CPU.""" | ||
| vae.to(device) |
There was a problem hiding this comment.
It would barely fit, otherwise. Plus we have to free some stuff anyway to do the actual computation. Once it's compiled it doesn't take much of a hit barring some displacement overhead which is likely justifiable given the cheap pricing of v5es. Does it make sense?
| ckpt_id = "black-forest-labs/FLUX.1-dev" | ||
|
|
||
| # --- Text encoding (CPU) --- | ||
| prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side" |
There was a problem hiding this comment.
This is probably fine. It's quite separate.
| xs.mark_sharding(param, mesh, tuple(spec)) | ||
|
|
||
| flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) | ||
| FlashAttention.DEFAULT_BLOCK_SIZES = { |
There was a problem hiding this comment.
Cc: @entrpn for those as it's copied from flux_inference.py.
| 2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec. | ||
| ``` | ||
|
|
||
| The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s). |
There was a problem hiding this comment.
I didn't get this part. Elaborate more? The block under "logger.info("starting compilation run...")" has the VAE compilation included.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
|
Additionally, @entrpn I am seeing recompilations with the following Is that expected? |
| xs.mark_sharding(param, mesh, tuple(spec)) | ||
|
|
||
| flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) | ||
| FlashAttention.DEFAULT_BLOCK_SIZES = { |
There was a problem hiding this comment.
If I remember correctly, these block sizes have been optimized for Trillium through some tests we ran internally. They can be kept as is as long the v5e's vmem can handle it. These could be optimized in the future specifically for v5e.
|
@tengomucho a gentle ping |
What does this PR do?
Add an example of model parallelism for Flux using PyTorch XLA. Tested on v5e-8.
Cc: @entrpn if you could review.