Code
Reference implementations:
https://github.com/CompVis/stable-diffusion (original CompVis codebase)
https://github.com/huggingface/diffusers (Hugging Face diffusers library)
Model loading in the CompVis codebase (from scripts/txt2img.py):

import torch
from ldm.util import instantiate_from_config  # builds the model class named in the YAML config

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")  # PyTorch Lightning checkpoint
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)  # m: missing keys, u: unexpected keys
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    model.cuda()
    model.eval()
    return model

model = load_model_from_config(config, opt["ckpt"])
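For context, a minimal sketch of where config and opt might come from; the YAML path ships with the CompVis repo, while the checkpoint path is an assumption:

from omegaconf import OmegaConf

config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
opt = {"ckpt": "models/ldm/stable-diffusion-v1/model.ckpt"}  # assumed download location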
Key methods on the loaded LatentDiffusion model:

model.decode_first_stage()        # VAE decoder: latents -> image
model.encode_first_stage()        # VAE encoder: image -> latents
model.get_learned_conditioning()  # CLIP text encoder: prompt -> conditioning embedding
model.apply_model()               # U-Net: predicts the noise to remove from the current latents
p_sample_ddim(...)                # sampler: computes the latents for the next denoising step
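A sketch of how these pieces fit together in a text-to-image pass, closely following scripts/txt2img.py (the prompt, step count, and guidance scale are illustrative):

from ldm.models.diffusion.ddim import DDIMSampler

sampler = DDIMSampler(model)  # calls model.apply_model and p_sample_ddim internally

prompt = "a photograph of an astronaut riding a horse"  # illustrative
c = model.get_learned_conditioning([prompt])   # CLIP text embedding
uc = model.get_learned_conditioning([""])      # unconditional embedding for guidance

# 4x64x64 is the latent shape for 512x512 images (4 channels, 8x spatial downsampling)
samples, _ = sampler.sample(S=50, conditioning=c, batch_size=1,
                            shape=[4, 64, 64], verbose=False,
                            unconditional_guidance_scale=7.5,
                            unconditional_conditioning=uc, eta=0.0)

image = model.decode_first_stage(samples)  # latents -> image tensor in [-1, 1]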
The equivalent operations with a diffusers pipeline:

pipe = load_stable_diffusion(sd_version=str(opt["model_id"]), precision_t=opt["dtype"])  # project-specific loader

# VAE: latents <-> images; note the scaling factor on both sides
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
latents = pipe.vae.encode(image).latent_dist.sample(generator) * pipe.vae.config.scaling_factor

pipe.encode_prompt(...)                          # CLIP text encoder
pipe.unet(..., return_dict=False)[0]             # U-Net noise prediction
pipe.scheduler.step(..., return_dict=False)[0]   # scheduler: next denoising step
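A minimal sketch of the full denoising loop built from these calls; argument names follow the current diffusers API, and the model id, prompt, and hyperparameters are illustrative:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt_embeds, negative_embeds = pipe.encode_prompt(
    "a photograph of an astronaut riding a horse",  # illustrative prompt
    device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=True,
)
embeds = torch.cat([negative_embeds, prompt_embeds])  # batch for classifier-free guidance

pipe.scheduler.set_timesteps(50)
latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
latents = latents * pipe.scheduler.init_noise_sigma

for t in pipe.scheduler.timesteps:
    latent_in = pipe.scheduler.scale_model_input(torch.cat([latents] * 2), t)
    noise_pred = pipe.unet(latent_in, t, encoder_hidden_states=embeds, return_dict=False)[0]
    uncond, cond = noise_pred.chunk(2)
    noise_pred = uncond + 7.5 * (cond - uncond)  # classifier-free guidance
    latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]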
Attention control: the code below replaces the forward pass of every self-attention layer in the U-Net so that a controller can observe (or, in a full Prompt-to-Prompt-style implementation, edit) the attention maps.

import torch
import torch.nn as nn
from einops import rearrange, repeat
from inspect import isfunction
from diffusers.models.attention_processor import Attention, AttnProcessor

def default(val, d):
    # Return val if it is set, otherwise the default d (calling it if it is a function).
    if val is not None:
        return val
    return d() if isfunction(d) else d
def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        # to_out may be a ModuleList (Linear followed by Dropout); use the Linear.
        if isinstance(self.to_out, nn.ModuleList):
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            torch.cuda.empty_cache()
            context = encoder_hidden_states
            mask = attention_mask
            h = self.heads
            q = self.to_q(x)
            # Self-attention when no context is given, cross-attention otherwise.
            context = default(context, x)
            k = self.to_k(context)
            v = self.to_v(context)
            # Split heads: (b, n, h*d) -> (b*h, n, d)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            # Scaled dot-product attention scores
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)
            sim = sim.softmax(dim=-1)
            # A full Prompt-to-Prompt implementation would hand the attention map
            # to the controller here, before it is applied to v.
            out = torch.einsum('b i j, b j d -> b i d', sim, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        if 'Attention' in net_.__class__.__name__:
            # Equal in_features for to_q and to_k identifies self-attention;
            # cross-attention keys are projected from the text-embedding dimension.
            if net_.to_k.in_features == net_.to_q.in_features:
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            return count
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, net in model.named_children():
        if "down" in name:
            cross_att_count += register_recr(net, 0, "down")
        elif "up" in name:
            cross_att_count += register_recr(net, 0, "up")
        elif "mid" in name:
            cross_att_count += register_recr(net, 0, "mid")
    controller.num_att_layers = cross_att_count
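A hedged usage sketch: the registration itself only counts the replaced layers and stores that count on the controller, so any object with a writable num_att_layers attribute works; the pipeline and controller class below are illustrative.

from diffusers import StableDiffusionPipeline

class NullController:
    def __init__(self):
        self.num_att_layers = 0  # filled in by register_attention_control

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
controller = NullController()
register_attention_control(pipe.unet, controller)
print(controller.num_att_layers)  # number of self-attention layers whose forward was replaced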