๊ฒ์ ์์ญ ์์ฑํ๋ ๋ชจ๋ธ ํ์ต์์ train ํ๊ฒฝ bfloat16 ์์๋ ๋ฌธ์ ์๋๊ฒ ํ
์คํธํ๊ฒฝ fp8๋ก ์์ฑํ๋ ํ
๋๋ฆฌ ๋ถ๋ถ์ด ๊ฐ๋ฐ์ด๋๋ ํ์. ์ด๊ฑด ์์ง๋ ํด๊ฒฐ๋ฐฉ๋ฒ ๋ชป์ฐพ์ใ
ใ
(๊ฑฐ์ ํ๋ฌ๋๊ฒ ์คํ๋ฐ๊ฟ๊ฐ๋ฉฐ ํ์ต/์คํ ํ๋๋ฐ ํด๊ฒฐ ๋ชปํจใ
).
์ด๊ฒ ์๋ฌด๋๋ ๋์ precision์์ ๋ฎ์ precision์ผ๋ก ๋ด๋ ค๊ฐ๋ฉด์ ๋ฐ์ํ๋ ๋ฌธ์ ๊ฐ ํ์คํด ๋ณด์ด๋๋ฐ, ๊ทผ๋ฐ ํํ์ด๋ฉด ๋ฑ ๊ฒฝ๊ณ๋ฉด์ ์ ๋ ๊ฒ ํญ๊ฒฉ์ด ๊ฐํด์ง๋ ์์ธ์ด ๋ญ๊น?
์ด๊ฑธ ํ์
ํ๊ธฐ ์ํด์๋ precision ๋ฎ์ถ๋ ๊ณผ์ ์ ์ข low-level ์์ ํ์
ํ ํ์์ฑ์ด ๋๊ปด์ง๋ค.
์๋์ ๊ทธ๋ํ๋ค์ LoRA ํ์ตํ ๋ ๋ด๊ฐ ๊ฒช์ ๋ค์ํ Prodigy Optimizer์ lr ์ค์ผ์ค๋ง ๋ณํ ๊ทธ๋ํ๋ค.

FLUX, QWEN-IMAGE ๋ชจ๋ ์ค norm์ ์์๋ณด์.
| API | FLUX | QWEN-IMAGE |
|---|---|---|
| VAE | Group Norm | RMSNorm |
| LDM | LayerNorm, RMSNorm | RMSNorm, AdaLayerNorm(?) |
RMSNorm ์ด๋์ด ์ฑํผ์ธ์ธ๊ฑฐ๊ฐ์๋ฐ, ํ๋ณด๊ธฐ ์ ์ ํ๋ฑ Norm๋ค ๋จผ์ ๋ฆฌ๋ง์ธ๋ ํ๊ณ ๊ฐ์ผ๊ฒ ์.
koyha-ss ๊ฐ ์ด ๋ค์ํ accelerate ์ค์ ์ฝ๋๋ฅผ ์ดํด๋ณธ๋ค.
acclerate config๋ accelerate ํ๋๊ทธ๋ก ์ฐ๊ธฐ ์ซ ๋ณต์กํ์ ๋ค์ ์ฝ๋๋ก ์ธํ
ํ๋ค.
๊ณต๋ถํ๊ณ ๋๋๊น ์ฝ๋๊ฐ ์ฝํ๋ค.
๋ถ์ฐํ์ต์ด ๋๋์ฒด ๋ญ์ง ์์ง๋ ์ฌ์ค ์ ๋ชจ๋ฅธ๋ค. GPU ์ฌ๋ฌ๊ฐ์ ๋ชจ๋ธ ๋ณต์ฌํด ๋๊ณ backprop ํ ๋ gradiend๋ง ์ทจํฉํด์ ๋ณด๋ด๋๊ฑด์ง, GPU ์ฌ๋ฌ๊ฐ๋ฅผ ํ๋์ฒ๋ผ ์ฐ๋๊ฑด์ง, ์ด๋ป๊ฒ ์ฐ๋๊ฑด์ง ๋ฑ๋ฑ.
์ด์ฐธ์ ํ์คํ ์ ๋ฆฌํ ์์ .
| ๊ตฌ๋ถ | ๊ฐ๋ |
|---|---|
| Single GPU | ํ GPU๋ก ํ์ต |
| Data Parallel (๋ฐ์ดํฐ ๋ณ๋ ฌ) | ๋ชจ๋ธ ๋ณต์ฌ N๊ฐ โ ๊ฐ GPU๊ฐ ๋ค๋ฅธ ๋ฏธ๋๋ฐฐ์น ํ์ต |
| Model Parallel (๋ชจ๋ธ ๋ณ๋ ฌ) | ๋ชจ๋ธ ์์ฒด๋ฅผ ์ฌ๋ฌ GPU์ ๋๋ ์ ์ ์ฅ |
| Pipeline Parallel | ๋ชจ๋ธ์ ์ฌ๋ฌ ํํธ๋ก ๋๋๊ณ ํ์ดํ๋ผ์ธ์ฒ๋ผ ์์ฐจ ์คํ |
| Tensor Parallel | ํ ๋ ์ด์ด ๋ด๋ถ ์ฐ์ฐ์ ์ฌ๋ฌ GPU๊ฐ ๋๋ ์ฒ๋ฆฌ |
| Multi-Node (๋ถ์ฐ ํ์ต) | GPU ์ฌ๋ฌ ๊ฐ๊ฐ ์ฌ๋ฌ ์๋ฒ(๋ ธ๋) ์ ํฉ์ด์ ธ ์์ |
| FSDP / ZeRO (DeepSpeed) | ๋ชจ๋ธ, ๊ทธ๋๋์ธํธ, ์ตํฐ๋ง์ด์ ๋ฅผ GPU ๊ฐ ๋ถํ ์ ์ฅ |
| Hybrid Parallel | ์ ๋ณ๋ ฌ ๋ฐฉ์์ ํผํฉ |
๋ด๊ฐ ์ ๋ง ๋ฆฌ์คํ ํ๋ฉด์ ์ฅ์ด ํจ๊ณ ์ถ์ koyha-ss ์ฝ๋๋ฅผ ๊ณต๋ถํ๋ฉฐ ์๊ฒ๋ ํ๋์จ์ด ๊ฐ์๊ธฐ Accelerator. ์ฃผ๋ก ์ฐ๋๊ฑด autocast, multi-gpu ์ด์ ๋๋ง ์๊ณ ๋๋จธ์ง๋ ์ค๊ฐ ์จ๋์๋๋ก๋ง ์จ์์.
์ฒ์ ์ธํ
ํ ๋ accelerater config -> ์ฝ๋ ๋ด๋ถ์ ๋ค์ํ accelerater ํจ์๋ค (unwrap, log, prepare ๋ฑ๋ฑ) -> ๊ทธ๋ฆฌ๊ณ python ์คํํ ๋ ์ฐ๋ accelerater ์ต์
๋ค (โmix-precision ๊ฐ์๊ฑฐ).
kohya-ss ์ด ์ฌ๋ ์ฝ๋ ๊ธฐ์ค์ผ๋ก ํ๋ํ๋ ์์๋ณด๋ ค๊ณ ํ๋ค.
์ง์ฅ์ํ 1๋
10๊ฐ์ ์ ๋ํ๋ฉด์ ํ๋ก์ ํธ ์์ฑ์๋ง ์ด์ ์ ๋๋๋ผ ๋์ถฉ ์ฅ ๋ณด๊ณ ๋์ด๊ฐ๋ ๊ฐ๋
๋ค์ด ๋ง์ ์ง๋์จ๊ธธ, ์จ๋ดค๋๊ฑฐ ๋์๋ณด๋ฉฐ ๊ฐ๋
์ข ๋ค์ ์ก๊ณ ์ ํฌ์คํ
ํ๊ฒ๋์์.
ํนํ ์ฝ๋ ๋ณด๋ฉด์ โ์ ๊ทธ๋ฐ๊ฐ๋ค~โ ํ๊ณ ๋์ด๊ฐ๋ ๋๋ค ์์ฃผ๋ก ๋ค์ ๋ฆฌ๋ง์ธ๋ ํ๋ฉด์ ๊ณ์ ์ ์ด๋๊ฐ ์์
def update_links(links, value):
for h, w in links:
table[h][w][0] = value
def init_links(links):
for h, w in links:
table[h][w] = ['EMPTY', {(h, w)}]
def update(command):
if len(command) == 3:
h, w, value = command
h, w = int(h) - 1, int(w) - 1
_, links = table[h][w]
update_links(links, value)
else:
value1, value2 = command
for h in range(n) :
for w in range(n) :
value, links = table[h][w]
if value == value1 : table[h][w][0] = value2
def merge(command):
h1, w1, h2, w2 = map(int, command)
if h1 == h2 and w1 == w2 : return
h1, w1, h2, w2 = h1-1, w1-1, h2-1, w2-1
value1, links1 = table[h1][w1]
value2, links2 = table[h2][w2]
links = links1.union(links2)
value = value1 if value1 != "EMPTY" else value2
update_links(links, value)
for h, w in links :
table[h][w][1] = links
def unmerge(command):
h, w = map(int, command)
h, w = h - 1, w - 1
value, links = table[h][w]
init_links(links)
table[h][w][0] = value
def print_(command):
h, w = map(int, command)
h, w = h - 1, w - 1
value, _ = table[h][w]
return value
# ํ
์ด๋ธ ๋ฐ value_storage ์ด๊ธฐํ
n = 50
table = [[["EMPTY", {(h, w)}] for w in range(n)] for h in range(n)]
def solution(commands):
answer = []
for command in commands:
command_type, values = command.split(' ', 1)[0], command.split(' ')[1:]
if command_type == 'UPDATE':
update(values)
elif command_type == 'MERGE':
merge(values)
elif command_type == 'UNMERGE':
unmerge(values)
elif command_type == 'PRINT':
answer.append(print_(values))
return answer
def expand_bin(bin_number) :
i = 0
while len(bin_number) > 2**(i+1) - 1 :
i += 1
return '0'*(2**(i+1) - 1 - len(bin_number)) + bin_number
def possiblity_check(bin_number) :
center = len(bin_number) // 2
if len(bin_number) == 1 :
return True
elif bin_number[center] == '0' :
return not ('1' in bin_number[:center] or '1' in bin_number[center + 1:])
else:
return possiblity_check(bin_number[:center]) and possiblity_check(bin_number[center + 1:])
def solution(numbers):
answer = []
for number in numbers :
bin_number = str(bin(number))[2:]
bin_number = expand_bin(bin_number)
answer.append(possiblity_check(bin_number)*1)
return answer