From b0b8008fa979afc5ff4b8e41528bf334665731ae Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 15 Apr 2026 15:36:06 +0530
Subject: [PATCH] fix autoencoderkl qwenimage for xla

---
 .../autoencoders/autoencoder_kl_qwenimage.py  | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
index f52071bf470b..eb45c3c7ee3c 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -180,7 +180,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
                     feat_cache[idx] = "Rep"
                     feat_idx[0] += 1
                 else:
-                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                    cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
                     if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                         # cache last frame of last two chunk
                         cache_x = torch.cat(
@@ -258,7 +258,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
 
         if feat_cache is not None:
             idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
 
@@ -277,7 +277,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
 
         if feat_cache is not None:
             idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
 
@@ -446,7 +446,7 @@ def __init__(
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         if feat_cache is not None:
             idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -471,7 +471,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         x = self.nonlinearity(x)
         if feat_cache is not None:
             idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -636,7 +636,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
@@ -658,7 +658,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         x = self.nonlinearity(x)
         if feat_cache is not None:
             idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            cache_x = x[:, :, -min(CACHE_T, x.shape[2]) :, :, :].clone()
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
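
Note on the change: every hunk replaces the fixed negative slice x[:, :, -CACHE_T:, :, :] with x[:, :, -min(CACHE_T, x.shape[2]):, :, :], so the slice start can never reach past the front of the temporal dimension when a chunk carries fewer than CACHE_T frames. In eager PyTorch the two forms return the same tensor, since an out-of-range negative start is clamped; the explicit clamp is presumably what the XLA path needs, as its slice handling is stricter about start indices. The sketch below is illustrative only and is not part of the patch; it assumes CACHE_T = 2, the cache depth used by these blocks.

    import torch

    CACHE_T = 2  # trailing frames cached between chunks (value assumed for this sketch)

    def cache_slice_old(x: torch.Tensor) -> torch.Tensor:
        # Pre-patch form: relies on eager PyTorch clamping an out-of-range negative start.
        return x[:, :, -CACHE_T:, :, :].clone()

    def cache_slice_new(x: torch.Tensor) -> torch.Tensor:
        # Patched form: clamp the slice length so the start index is always in range,
        # even when the chunk holds fewer than CACHE_T frames.
        return x[:, :, -min(CACHE_T, x.shape[2]):, :, :].clone()

    # Both forms agree in eager mode for chunks shorter and longer than CACHE_T.
    for t in (1, 2, 5):
        x = torch.randn(1, 4, t, 8, 8)  # hypothetical [B, C, T, H, W] chunk
        assert torch.equal(cache_slice_old(x), cache_slice_new(x))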