Reference code:

https://github.com/CompVis/stable-diffusion (CompVis Stable Diffusion)

https://github.com/huggingface/diffusers (Hugging Face diffusers)

def load_model_from_config(config, ckpt, verbose=False, device="cuda"):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Config object whose ``model`` entry is passed to
            ``instantiate_from_config`` (presumably an OmegaConf config as used
            by the CompVis stable-diffusion repo — confirm).
        ckpt: Path to a checkpoint file readable by ``torch.load``. Either a
            checkpoint dict with a ``"state_dict"`` entry (optionally with
            ``"global_step"``) or a bare state dict.
        verbose: If true, print the missing and unexpected state-dict keys
            reported by the non-strict ``load_state_dict``.
        device: Device the model is moved to after loading. The default
            ``"cuda"`` matches the previous hard-coded ``model.cuda()``.

    Returns:
        The loaded model, moved to ``device`` and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # SECURITY NOTE(review): torch.load unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    pl_sd = torch.load(ckpt, map_location="cpu")  # load on CPU first, move later
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Training checkpoints (presumably PyTorch Lightning, given the
    # "global_step" key) nest the weights under "state_dict". Fall back to
    # treating the loaded object itself as a bare state dict.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    # strict=False: tolerate key mismatches and report them instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    model.to(device)
    model.eval()
    return model
# NOTE(review): API cheat-sheet, not runnable as written. The calls below are
# listed without their arguments, and `config` / `opt` are not defined here.
# They appear to be the model-object entry points of the CompVis code base
# linked at the top of this file.
model = load_model_from_config(config, opt["ckpt"])    
# Presumably latent -> image through the first-stage (autoencoder) model; confirm.
model.decode_first_stage()
# Presumably image -> latent through the first-stage (autoencoder) model; confirm.
model.encode_first_stage()

model.get_learned_conditioning()    # CLIP encoding (presumably of the prompt)

model.apply_model()                 # U-Net predicts the noise to remove from the current image

p_sample_ddim(...)                  # sampler computes the image for the next denoising iteration
# NOTE(review): diffusers API cheat-sheet, not runnable as written.
# pipe.unet(...) and pipe.scheduler.step(...) omit their positional arguments
# (a syntax error). `latents`, `generator` and `opt` are not defined in this
# section, and `load_stable_diffusion` is a helper not shown here.
pipe = load_stable_diffusion(sd_version=str(opt["model_id"]), precision_t=opt["dtype"])
# Latents -> image: undo the VAE scaling factor, then decode; [0] takes the first returned element.
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
# Image -> latents: sample from the encoder's latent distribution, then apply the scaling factor.
latents = pipe.vae.encode(image).latent_dist.sample(generator) * pipe.vae.config.scaling_factor

# Prompt encoding; presumably the counterpart of model.get_learned_conditioning().
pipe.encode_prompt()

# U-Net forward; [0] takes the first returned element (presumably the noise prediction — confirm).
pipe.unet(, return_dict=False)[0]

# Scheduler step; [0] takes the first returned element (presumably the next-iteration latents — confirm).
pipe.scheduler.step(, return_dict=False)[0]
import torch
import torch.nn as nn
from einops import rearrange, repeat
from inspect import isfunction
from diffusers.models.attention_processor import Attention, AttnProcessor

def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    If the fallback ``d`` is a plain Python function (as judged by
    ``inspect.isfunction``), it is called with no arguments and its result is
    returned. The fallback is therefore only evaluated when it is needed.
    Any other ``d`` (including classes and builtins) is returned as-is.
    """
    if val is None:
        if isfunction(d):
            return d()
        return d
    return val

def register_attention_control(model, controller):
    """Replace the ``forward`` of selected attention layers inside ``model``.

    The function walks the direct children of ``model`` whose names contain
    ``"down"``, ``"up"`` or ``"mid"`` (tested in that order). Inside them it
    recursively finds modules whose class name contains ``"Attention"``.
    Each such module whose ``to_k`` and ``to_q`` projections take the same
    input width gets a plain scaled-dot-product attention ``forward`` (built
    by ``ca_forward``). Attention modules whose widths differ are left
    untouched. Presumably these are cross-attention layers, whose keys come
    from a differently sized context; confirm for your model.

    Args:
        model: Module whose named children are scanned. Presumably a diffusers
            ``UNet2DConditionModel`` (``down_blocks`` / ``mid_block`` /
            ``up_blocks``); confirm.
        controller: Receives the number of patched layers as
            ``controller.num_att_layers``. It is not otherwise used here.
            NOTE(review): the controller is never invoked on the attention
            maps; confirm that is intended.
    """
    def ca_forward(self, place_in_unet):
        """Build a replacement ``forward`` bound to attention module ``self``.

        ``place_in_unet`` ("down" / "up" / "mid") is passed in by the caller but
        is not used by the returned function.
        """
        # The output projection may be stored in a ModuleList. Only its first
        # element is applied; any later entries are skipped.
        to_out = self.to_out
        if isinstance(to_out, nn.modules.container.ModuleList):
            to_out = to_out[0]

        def forward(x, encoder_hidden_states=None, attention_mask=None, **kwargs):
            # Extra keyword arguments (e.g. cross-attention kwargs forwarded by
            # the calling block) are accepted and ignored, so callers that pass
            # them do not fail with TypeError.
            # NOTE(review): runs on every attention call; likely costly — confirm needed.
            torch.cuda.empty_cache()

            h = self.heads

            q = self.to_q(x)
            context = default(encoder_hidden_states, x)  # no context -> attend over x itself
            k = self.to_k(context)
            v = self.to_v(context)

            # Fold heads into the batch dim; assumes (batch, tokens, heads*dim) inputs — TODO confirm.
            q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if attention_mask is not None:
                mask = rearrange(attention_mask, 'b ... -> b (...)')
                # Same batch-major (b h) ordering as q/k/v above.
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                if mask.is_floating_point():
                    # Additive bias mask (0 = keep, large negative = block),
                    # presumably the convention diffusers uses — confirm.
                    sim = sim + mask.to(sim.dtype)
                else:
                    # Boolean keep-mask: positions where mask is False are blocked.
                    max_neg_value = -torch.finfo(sim.dtype).max
                    sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)

            out = torch.einsum('b i j, b j d -> b i d', attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

            return to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Patch matching attention modules under ``net_``; return the updated count."""
        if 'Attention' in net_.__class__.__name__:
            # Attention modules are patched (or skipped) but never descended into.
            if net_.to_k.in_features == net_.to_q.in_features:
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            return count
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    att_count = 0
    for name, sub_net in model.named_children():
        if "down" in name:
            att_count += register_recr(sub_net, 0, "down")
        elif "up" in name:
            att_count += register_recr(sub_net, 0, "up")
        elif "mid" in name:
            att_count += register_recr(sub_net, 0, "mid")

    controller.num_att_layers = att_count