| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 
 | apply_canny = CannyDetector()
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
 model = model.cuda()
 ddim_sampler = DDIMSampler(model)
 
 
 def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
 with torch.no_grad():
 
 img = resize_image(HWC3(input_image), image_resolution)
 H, W, C = img.shape
 
 detected_map = apply_canny(img, low_threshold, high_threshold)
 detected_map = HWC3(detected_map)
 
 control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
 control = torch.stack([control for _ in range(num_samples)], dim=0)
 control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
 if seed == -1:
 seed = random.randint(0, 65535)
 seed_everything(seed)
 
 if config.save_memory:
 model.low_vram_shift(is_diffusing=False)
 
 
 
 cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
 un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
 shape = (4, H // 8, W // 8)
 
 if config.save_memory:
 model.low_vram_shift(is_diffusing=True)
 
 model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
 samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
 shape, cond, verbose=False, eta=eta,
 unconditional_guidance_scale=scale,
 unconditional_conditioning=un_cond)
 
 if config.save_memory:
 model.low_vram_shift(is_diffusing=False)
 
 x_samples = model.decode_first_stage(samples)
 x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
 
 results = [x_samples[i] for i in range(num_samples)]
 return [255 - detected_map] + results
 
 |