STORYDIFFUSION: CONSISTENT SELF-ATTENTION FOR LONG-RANGE IMAGE AND VIDEO GENERATION

核心问题是什么?

对于最近基于扩散的生成模型,在一系列生成的图像中保持一致的内容,特别是那些包含主题和复杂细节的图像,提出了重大挑战。

本文方法

  1. 一致性自注意力:一种新的自注意力计算方法,它显着提高了生成图像之间的一致性,并以零样本方式增强了流行的基于预训练扩散的文本到图像模型。
  2. 语义运动预测器:一种新颖的语义空间时间运动预测模块。它被训练来估计语义空间中两个提供的图像之间的运动条件。该模块将生成的图像序列转换为具有平滑过渡和一致主题的视频,比仅基于潜在空间的模块更加稳定,特别是在长视频生成的情况下。

效果

通过合并这两个新颖的组件,我们的框架(称为 StoryDiffusion)可以描述基于文本的故事,其中包含包含丰富内容的一致图像或视频。

核心贡献是什么?

  1. 一致性自注意力(Consistent Self-Attention):提出了一种新的自注意力计算方法,能够在生成图像的过程中显著提高内容的一致性,例如保持角色身份和服饰的一致性。这种方法不需要额外的训练,可以以零样本(zero-shot)的方式插入到现有的扩散模型中。

  2. 语义运动预测器(Semantic Motion Predictor):为了将一系列图像转换为具有平滑过渡的视频,提出了一种新的模块,它在语义空间中估计两张图像之间的运动条件。这种方法在长视频生成中特别有效,能够生成比仅基于潜在空间的方法更稳定的视频。

  3. StoryDiffusion框架:结合了上述两个新组件,能够根据文本描述生成一致性的图像或视频,涵盖了丰富多样的内容。这种方法在视觉故事生成方面进行了开创性的探索。

  4. 无需训练的一致性图像生成:StoryDiffusion能够在不需要训练的情况下,通过分割文本故事并使用提示生成一系列高度一致的图像,这些图像能够有效地叙述一个故事。

  5. 滑动窗口技术:为了支持长故事的生成,StoryDiffusion实现了与滑动窗口相结合的一致性自注意力,这消除了峰值内存消耗对输入文本长度的依赖,使得长故事的生成成为可能。

  6. 预训练的运动模块:结合预训练的运动模块,Semantic Motion Predictor能够生成比现有条件视频生成方法(如SEINE和SparseCtrl)更平滑、更稳定的视频帧。

大致方法是什么?

第一步:根据描述生成一系列一致的图像,作为关键帧
第二步:在关键帧之间生成中间动作,补全成视频

免训练一致图像生成

StoryDiffusion 生成主题一致图像的流程。为了创建主题一致的图像来描述故事,我们将一致的自我注意力融入到预先训练的文本到图像的扩散模型中。我们将故事文本拆分为多个提示,并使用这些提示批量生成图像。一致的自注意力在批量的多个图像之间建立连接,以实现主题一致性。

[?] 这个图画得不对?代码上还有input_id_images作为输入

pipeline

# 简化代码,保留关键过程,源码请查看github
class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    @torch.no_grad()
    def __call__(...):
        # 0. Default height and width to unet
        ...
        # 1. Check inputs. Raise error if not correct
        ...
        # 2. Define call parameters
        ...
        # 3. Encode input prompt,同时将reference image注入到embedding中
        for prompt in prompt_arr:
            # 3.1 Encode input prompt with trigger world
            ...
            # 3.2 Encode input prompt without the trigger word for delayed conditioning
            # 分别生成带trigger world的embedding和不带trigger world的embedding是训练策略。先保证无trigger world的普通生成质量,再加入trigger world。
            ...
            # 5. Prepare the input ID images
            ...
        # 7. Prepare timesteps
        ...
        # 8. Prepare latent variables,latents的状态会累积
        latents = self.prepare_latents(
                ...
                latents, # init latents可以为None
            )
        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        ...
        # 10. Prepare added time ids & embeddings
        ...
        # 11. Denoising loop
        ...
        # 12. decoder and get image
        ...
        return image

使用Textual Inversion把Reference Image注入到文本中

# 简化代码,保留关键过程,源码请查看github
# 3. Encode input prompt
for prompt in prompt_arr:
    # 3.1 Encode input prompt with trigger world
    (
        prompt_embeds, # 记录所有prompt的embeds
        pooled_prompt_embeds, # 记录当前prompt的embeds
        class_tokens_mask,
    ) = self.encode_prompt_with_trigger_word(
        prompt=prompt,
        prompt_2=prompt_2,
        nc_flag = nc_flag,
        ...
    )
    # 3.2 Encode input prompt without the trigger word for delayed conditioning
    # 先生成不带trigger world的prompt
    # encode, 此处的encode是prompt转为token的意思,与上下文中的Encode不同
    tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
    # remove trigger word token
    trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word)
    if not nc_flag:
        tokens_text_only.remove(trigger_word_token)
    # then decode, token -> prompt
    prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False)
    # 再Encode
    ...
    # 5. Prepare the input ID images
    ...
    if not nc_flag:
        # 6. Get the update text embedding with the stacked ID embedding
        prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # [B, S, D] -> [B, S*N, D]
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # [B, S*N, D] -> [B*N, S, D] --- 这个直接repeat(N,1,1)有什么区别?
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        pooled_prompt_embeds_arr.append(pooled_prompt_embeds)
        pooled_prompt_embeds = None
第一步:对input prompt进行Encode

对input prompt进行Encode。其中prompt中是否包含trigger world token没有本质区别,只是一种训练策略。Encode的过程包括tokenize和text encode。

prompt --(tokenize)--> token --(text encode)--> embedding。  

其中tokenize的过程有一些特殊处理,过程如下:

输入输出操作
world listtoken listtokenizer.encode
token listclass_token_index, clean_input_ids listtoken list中与trigger world token不同的token被放入clean input ids中,与trigger world token相同的token则被丢弃。
作者认为trigger world代表reference image,是一个名词,那么它前面的词就是形容reference image的特征的关键词,代码里称其为class。这个关键词在clean input ids中的index被记录到class_token_index list中。
实际上,只允许trigger world token在prmopt token中出现一次,因此也只有关键词及其在clean input ids中的index。
class_token_index, clean input ids = [token, token, class, token, ...], reference image的数量clean input ids = [token, token, class,class, class, token, ...]根据reference image的数量重复class token
clean input idsclean input ids把clean input ids补充或截断到固定长度
clean input idsclass_tokens_mask标记clean input ids中哪些是class
clean input idsprompt_embeds对每一个token逐个进行embedding并concat

具体代码如下:

def encode_prompt_with_trigger_word(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    num_id_images: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    class_tokens_mask: Optional[torch.LongTensor] = None,
    nc_flag: bool = False,
):
    ...
    # Find the token id of the trigger word
    image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)

    # Define tokenizers and text encoders
    ...

    if prompt_embeds is None:
        ...
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            input_ids = tokenizer.encode(prompt)
            # Find out the corresponding class word token based on the newly added trigger word token
            for i, token_id in enumerate(input_ids):
                if token_id == image_token_id:
                    class_token_index.append(clean_index - 1)
                else:
                    clean_input_ids.append(token_id)
                    clean_index += 1
            # 异常处理
            ...

            class_token_index = class_token_index[0]

            # Expand the class word token and corresponding mask
            class_token = clean_input_ids[class_token_index]
            clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \
                clean_input_ids[class_token_index+1:]

            # Truncation or padding
            max_len = tokenizer.model_max_length
            if len(clean_input_ids) > max_len:
                clean_input_ids = clean_input_ids[:max_len]
            else:
                clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
                    max_len - len(clean_input_ids)
                )

            class_tokens_mask = [True if class_token_index <= i < class_token_index+num_id_images else False \
                    for i in range(len(clean_input_ids))]

            # 维度统一
            ...

            prompt_embeds = text_encoder(
                clean_input_ids.to(device),
                output_hidden_states=True,
            )

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case

    return prompt_embeds, pooled_prompt_embeds, class_tokens_mask
第二步:把reference image与prompt融合

先对每个reference image依次编码,然后让reference image embedding与prompt embedding中标记为class的embedding做融合。融合过程为MLP。

具体代码如下:

class FuseModule(nn.Module):
    def __init__(self, embed_dim):
        ...

    def fuse_fn(self, prompt_embeds, id_embeds):
        stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
        stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
        stacked_id_embeds = self.mlp2(stacked_id_embeds)
        stacked_id_embeds = self.layer_norm(stacked_id_embeds)
        return stacked_id_embeds

    def forward(self, prompt_embeds, id_embeds, class_tokens_mask, ) -> torch.Tensor:
        # id_embeds shape: [b, max_num_inputs, 1, 2048]
        id_embeds = id_embeds.to(prompt_embeds.dtype)
        # 维度匹配
        ...
        valid_id_embeds = ...

        # slice out the image token embeddings
        image_token_embeds = prompt_embeds[class_tokens_mask]
        stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
        ...
        return updated_prompt_embeds

denoise step

denoise step 使用 UNet-based diffusion network + CFG训练策略,输入由以下方式构成:

  • latent_model_input
    • latents
    • latents
  • current_prompt_embeds
    • negative_prompt_embeds
    • prompt_embeds(text only)
      • id pixel values
      • prompt embeddings
      • class token mask
  • added_cond_kwargs
    • add text embeddings
      • negative_pooled_prompt_embeds
      • pooled_prompt_embeds(text_only)
    • add time embeddings
# 11. Denoising loop
for i, t in enumerate(timesteps):
    latent_model_input = (
        torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    )
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

    if i <= start_merge_step or nc_flag:  
        current_prompt_embeds = torch.cat(
            [negative_prompt_embeds, prompt_embeds_text_only], dim=0
        )
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
    else:
        current_prompt_embeds = torch.cat(
            [negative_prompt_embeds, prompt_embeds], dim=0
        )
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    # predict the noise residual
    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=current_prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]
    # perform guidance
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    if do_classifier_free_guidance and guidance_rescale > 0.0:
        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

    # compute the previous noisy sample x_t -> x_t-1
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
    ...
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
return image

图像一致性问题

在生成过程中在batch内的图像之间建立连接。保持一批图像中角色的一致性。
方法:将一致性自注意力插入到图像生成模型现有 U-Net 架构中原始自注意力的位置,并重用原始自注意力权重以保持免训练和可插拔。

定义一批图像特征为: \(\mathcal{I} ∈ R^{B×N×C}\) ,其中 B、N 和 C 分别是batch size、每个图像中的token数量和channel数。
通常情况下,第i张图像的Attention函数的输入xQ、xK、xV由第i图像的特征(1×N×C)通过映射得到。

本文为了在batch中的图像之间建立交互以保持一致性,修改为从batch中的其他图像特征中采样一些token加入第i个图像的特征中,第i张图像的特征变为1×(W * N * sampling_rate + N)×C,其中第一部分为从其它图像采样来的token,第二部分为自己原有的token。

def ConsistentSelfAttention(images_features, sampling_rate, tile_size): 
    """ 
    images_tokens: [B, C, N] # 论文上是这么写的,但我认为是[B, N, C] 
    sampling_rate: Float (0-1) 
    tile_size: Int 
    """ 
    output = zeros(B, N, C), count = zeros(B, N, C), W = tile_size 
    for t in range(0, N - tile_size + 1): 
        # Use tile to override out of GPU memory 
        tile_features = images_tokens[t:t + W, :, :] 
        reshape_featrue = tile_feature.reshape(1, W*N, C).repeat(W, 1, 1) 
        sampled_tokens = RandSample(reshape_featrue, rate=sampling_rate, dim=1) 
        # Concat the tokens from other images with the original tokens 
        token_KV = concat([sampled_tokens, tile_features], dim=1) 
        token_Q = tile_features 
        # perform attention calculation: 
        X_q, X_k, X_v = Linear_q(token_Q), Linear_k(token_KV), Linear_v(token_KV) 
        output[t:t+w, :, :] += Attention(X_q, X_k, X_v) 
        count[t:t+w, :, :] += 1 
    output = output/count 
    return output

用于视频生成的语义运动预测器

生成过渡视频的方法的流程。为了有效地建模角色的大动作,我们将条件图像编码到图像语义空间中,以编码空间信息并预测过渡嵌入。然后使用视频生成模型对这些预测的嵌入进行解码,嵌入作为交叉注意力中的控制信号来指导每帧的生成。

任务描述:通过在每对相邻图像之间插入帧,可以将生成的字符一致图像的序列进一步细化为视频。这可以看作是一个以已知开始帧和结束帧为条件的视频生成任务

主要挑战:当两​​幅图像之间的差异较大时,SparseCtrl (Guo et al., 2023) 和 SEINE (Chen et al., 2023) 等最新方法无法稳定地连接两个条件图像。

当前解决方法的问题:这种限制源于它们仅依赖时间模块来预测中间帧,这可能不足以处理图像对之间的大状态间隙。时间模块在每个空间位置上的像素内独立操作,因此,在推断中间帧时可能没有充分考虑空间信息。这使得对长距离且具有物理意义的运动进行建模变得困难

本文解决方法:语义运动预测器它将图像编码到图像语义空间中以捕获空间信息,从给定的起始帧和结束帧实现更准确的运动预测。

  1. 引入CLIP作为图像编码器,利用其零样本的能力,建立从 RGB 图像到图像语义空间中的向量的映射,对空间信息进行编码。
  2. 使用CLIP 图像编码器,将给定的起始帧Fs和结束帧Fe压缩为图像语义空间向量Ks、Ke。
  3. 线性插值,将 Ks 和 Ke 扩展为序列 K1, K2, ..., KL,其中 L 是所需的视频长度。
  4. 在图像语义空间中,训练transformer-based structure predictor来预测每个中间帧。predictor的输入是序列 K1, K2, ..., KL。输出是过渡帧。
  5. 将这些图像语义嵌入 P1、P2、...、PL 定位为控制信号,将视频扩散模型定位为解码器,将图像语义空间中的这些预测帧解码为最终的过渡视频。

训练与验证

数据集

loss

训练策略

有效

  1. 实验结果:论文通过实验验证了StoryDiffusion在生成一致性图像和过渡视频方面的性能,与最新的技术进行了比较,并展示了其优越性。

局限性

启发

  1. 应用潜力:StoryDiffusion的提出为可控图像和视频生成领域带来了新的启示,尤其是在需要讲述故事的应用场景中,如动画制作、游戏开发或虚拟现实体验。

遗留问题

参考材料

  1. 项目主页:https://StoryDiffusion.github.io