0%

Training Control Net with fill50k Dataset

如何使用自定义的数据集训练自己的control net:

ControlNet/docs/train.md at main · lllyasviel/ControlNet (github.com)

由于别的数据集都太大了,打算直接用先用Fill50K 来试一下

image-20240429210739733

模型已经知道的:已经知道什么是“青色”,什么是“圆形”,什么是“粉红色”,什么是“背景”

不知道的:control image的图像的含义

模型目的:训练模型使得其能够正确往圆圈和背景里填正确的颜色

首先下载,然后解压到

1
2
3
ControlNet/training/fill50k/prompt.json
ControlNet/training/fill50k/source/X.png
ControlNet/training/fill50k/target/X.png

接着测试一下有没有成功读进来:(使用tutorial_dataset.py)

image-20240430172317915

发现是可以读进来的

然后选择一个预训练好的stable diffusion模型:

runwayml/stable-diffusion-v1-5 at main (hf-mirror.com)

里的“v1-5-pruned.ckpt”

然后就可以使用

1
python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt

来将处理后的模型 (SD+ControlNet) 保存在“./models/control_sd15_ini.ckpt”位置

image-20240430224930701

最后就可以开始train了

运行tutorial_train.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/control_sd15_ini.ckpt'
batch_size = 3
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)

发现CUDA out of memory

查了发现24g的显存都爆了,只能把batch size设成3,勉强够用。。。

最终运行了大概1.3个epoch

效果:

按时间顺序从后往前排序

image-20240503194300398

image-20240503194324822

四个一组,分别是prompt,input,ground truth, output

可以看到随着训练的增加,一开始的输出非常接近真实世界(权重都集中在原本的模型上),带有真实物体的花纹等等;后面无论是

  1. 填充的颜色
  2. 圆圈的位置和大小

逐渐能够接近ground truth

证明训练是有效的

训练后的模型文件在ControlNet\lightning_logs\version_0\checkpoints中