File size: 3,979 Bytes
9042918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations

import os
import pathlib
import shlex
import shutil
import subprocess

import gradio as gr
import PIL.Image
import torch

# Prepend the relative ``lora`` directory to PYTHONPATH in this process's
# environment. Child processes started later inherit the environment, so
# presumably this lets the training script import from ``lora`` -- confirm
# against the repository layout.
# NOTE(review): the separator is a hard-coded ':' and a trailing ':' is left
# behind when PYTHONPATH was previously unset.
os.environ['PYTHONPATH'] = f'lora:{os.getenv("PYTHONPATH", "")}'


def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad *image* with black borders so that it becomes square.

    The original picture is centred along its shorter axis on a new canvas
    whose side equals the longer dimension. The image mode is preserved.

    Args:
        image: Source image of any size.

    Returns:
        ``image`` itself when it is already square, otherwise a new square
        image in the same mode with the source pasted in the middle.
    """
    w, h = image.size
    if w == h:
        return image

    side = max(w, h)
    # Pillow rejects a 3-tuple fill colour for single- and dual-band modes
    # ('L', '1', 'I', 'F', 'LA', ...) with a TypeError, which used to make
    # every non-square grayscale upload crash. Use the scalar 0 there and
    # keep the RGB tuple for palette and 3/4-band modes, where it was
    # already accepted, so their output is unchanged.
    if image.mode == 'P' or len(image.getbands()) >= 3:
        fill = (0, 0, 0)
    else:
        fill = 0

    canvas = PIL.Image.new(image.mode, (side, side), fill)
    # One of the two offsets is always zero: the longer axis already fills
    # the canvas, only the shorter one needs centring.
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas


class Trainer:
    """Prepare a training dataset and run the LoRA DreamBooth script.

    The training itself is executed in a child process
    (``accelerate launch lora/train_lora_dreambooth.py``). A simple boolean
    flag is used to refuse a second run while one is in progress; it is a
    best-effort guard, not a lock.
    """

    def __init__(self):
        # Busy flag and the message shown while a run is in progress.
        self.is_running = False
        self.is_running_message = 'Another training is in progress.'

        # Everything (dataset, weights, train.sh) lives under ``results``,
        # relative to the current working directory.
        self.output_dir = pathlib.Path('results')
        self.instance_data_dir = self.output_dir / 'training_data'

    def check_if_running(self) -> dict:
        """Return a Gradio update whose value describes the current state."""
        if self.is_running:
            return gr.update(value=self.is_running_message)
        else:
            return gr.update(value='No training is running.')

    def cleanup_dirs(self) -> None:
        """Remove the whole output directory, ignoring any errors."""
        shutil.rmtree(self.output_dir, ignore_errors=True)

    def prepare_dataset(self, concept_images: list, resolution: int) -> None:
        """Write the uploaded images as square RGB JPEGs into the dataset dir.

        Args:
            concept_images: Uploaded file objects; each must expose the path
                of the file on disk through a ``name`` attribute.
            resolution: Side length, in pixels, of the saved images.
        """
        # ``cleanup_dirs`` ignores errors, so the directory may still exist.
        self.instance_data_dir.mkdir(parents=True, exist_ok=True)
        for i, temp_path in enumerate(concept_images):
            # The context manager closes the underlying file handle, which
            # was previously left open for every uploaded image.
            with PIL.Image.open(temp_path.name) as source:
                image = pad_image(source)
                image = image.resize((resolution, resolution))
                image = image.convert('RGB')
                out_path = self.instance_data_dir / f'{i:03d}.jpg'
                image.save(out_path, format='JPEG', quality=100)

    def _build_command(
        self,
        base_model: str,
        resolution: int,
        concept_prompt: str,
        n_steps: int,
        learning_rate: float,
        train_text_encoder: bool,
        learning_rate_text: float,
        gradient_accumulation: int,
        fp16: bool,
        use_8bit_adam: bool,
    ) -> list[str]:
        """Build the argument vector for the training subprocess.

        Each option is its own list element, so user-supplied values (the
        prompt in particular) are passed through verbatim regardless of the
        quotes or whitespace they contain.
        """
        args = [
            'accelerate',
            'launch',
            'lora/train_lora_dreambooth.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--instance_data_dir={self.instance_data_dir}',
            f'--output_dir={self.output_dir}',
            f'--instance_prompt={concept_prompt}',
            f'--resolution={resolution}',
            '--train_batch_size=1',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            f'--learning_rate={learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=0',
            f'--max_train_steps={n_steps}',
        ]
        if fp16:
            args += ['--mixed_precision', 'fp16']
        if use_8bit_adam:
            args.append('--use_8bit_adam')
        if train_text_encoder:
            args += [
                '--train_text_encoder',
                f'--learning_rate_text={learning_rate_text}',
                '--color_jitter',
            ]
        return args

    def run(
        self,
        base_model: str,
        resolution_s: str,
        concept_images: list | None,
        concept_prompt: str,
        n_steps: int,
        learning_rate: float,
        train_text_encoder: bool,
        learning_rate_text: float,
        gradient_accumulation: int,
        fp16: bool,
        use_8bit_adam: bool,
    ) -> tuple[dict, list[pathlib.Path]]:
        """Prepare the dataset, run the training script and collect weights.

        Args:
            base_model: Value for ``--pretrained_model_name_or_path``.
            resolution_s: Training resolution as a string; converted with
                ``int``.
            concept_images: Uploaded image files, or ``None``.
            concept_prompt: Value for ``--instance_prompt``.
            n_steps: Value for ``--max_train_steps``.
            learning_rate: Value for ``--learning_rate``.
            train_text_encoder: Adds ``--train_text_encoder``,
                ``--learning_rate_text`` and ``--color_jitter`` when true.
            learning_rate_text: Value for ``--learning_rate_text``.
            gradient_accumulation: Value for
                ``--gradient_accumulation_steps``.
            fp16: Adds ``--mixed_precision fp16`` when true.
            use_8bit_adam: Adds ``--use_8bit_adam`` when true.

        Returns:
            A Gradio update carrying the status message and the sorted list
            of ``*.pt`` files found in the output directory. When another
            run is in progress the list is empty.

        Raises:
            gr.Error: If CUDA is unavailable, no images were uploaded, or
                the prompt is empty.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')

        if self.is_running:
            return gr.update(value=self.is_running_message), []

        # An empty upload list is as unusable as a missing one.
        if not concept_images:
            raise gr.Error('You need to upload images.')
        if not concept_prompt:
            raise gr.Error('The concept prompt is missing.')

        resolution = int(resolution_s)

        # Mark the trainer busy before touching the output directory so a
        # second request cannot wipe it mid-run, and always clear the flag
        # again, even if preparation or the subprocess launch raises.
        self.is_running = True
        try:
            self.cleanup_dirs()
            self.prepare_dataset(concept_images, resolution)

            command = self._build_command(
                base_model=base_model,
                resolution=resolution,
                concept_prompt=concept_prompt,
                n_steps=n_steps,
                learning_rate=learning_rate,
                train_text_encoder=train_text_encoder,
                learning_rate_text=learning_rate_text,
                gradient_accumulation=gradient_accumulation,
                fp16=fp16,
                use_8bit_adam=use_8bit_adam,
            )

            # Keep a shell-quoted copy of the exact command for reference.
            with open(self.output_dir / 'train.sh', 'w',
                      encoding='utf-8') as f:
                f.write(' '.join(shlex.quote(arg) for arg in command))

            res = subprocess.run(command)
        finally:
            self.is_running = False

        if res.returncode == 0:
            result_message = 'Training Completed!'
        else:
            result_message = 'Training Failed!'
        weight_paths = sorted(self.output_dir.glob('*.pt'))
        return gr.update(value=result_message), weight_paths