Norm
FLUX, QWEN-IMAGE 모듈 중 norm을 알아보자.
- (내가아는 오픈소스중 2황이기 때문에…)
| API | FLUX | QWEN-IMAGE |
|---|---|---|
| VAE | Group Norm | RMSNorm |
| LDM | LayerNorm, RMSNorm | RMSNorm, AdaLayerNorm(?) |
RMSNorm 이놈이 챔피언인거같은데, 파보기 전에 틀딱 Norm들 먼저 리마인드 하고 가야겠음.
1. 예전 Norm들.
학부, 대학원 때 개~xx 이미지로 각종 Norm 기법 비교해놓은 이미지 기억나는데 당시에 이해도 못했고, 그냥 숫자, 테서로 파악하는게 나을꺼같아서 다시 정리해본다.
1.1 BatchNorm
CNN에서 특히 강력하고 RNN/Transformer 같은 sequence 모델과 궁합 구림. 그래서 아예 사장되었나보다.
특히 BatchNorm은 배치 사이즈 작으면 아예 학습이 안되는데 최소 64개였나? 그정도 필요하다고 했던거같다.
예전에 GAN 학습할 때 InstanceNorm 대신 BatchNorm batch=8로 학습했었을 때 이미지가 엄청 뿌옇게 거의 스모커 대령처럼 나오는걸 확인한 적이 있었음.
그러나 이것부터 좀 알아야 각종 변형 Norm들이 구체적으로 파라미터 어떻게 생겼는지 탁 보고 감 잡힐듯.
- 입력: (B, C, H, W), $\mu, \sigma$: (C,)
배치 싹 끌어와서 통계를 내다. (C, B x H x W )로 생각하면 편할듯. CNN에서 image latent channel별로 다른 특성을 지닌다고 가정하기에 채널별로만 평/분 계산 후 뿌려준다.
요약하면
“해당 축에 있는 값들이 같은 의미의 feature인가?” 를 기준으로 Norm을 진행한다. 같은 의미를 지닌 애들끼리 평/분 정규화. CNN에서 C(channel)는 “의미가 고정된 feature map”이다.
- C=0 → edge detector
- C=1 → texture detector
- C=2 → color mixing
1.2 InstanceNorm
이건 특징이 이미지에서 채널별로 평/분을 구한다. 이렇게 되면 이미지별로 독특하고, 두드러지는 특징들이 완화 되는데, 그렇게 perceptual한 특징을 날려버리는 순간 이미지 복원시 매우 안좋은 이미지가 생성될 꺼같다.
그러나 뭔가 주요 특징만 살리고 그 외에껄 바꾸고싶은 edit, image-to-image 태스크에서는 뭔가 InstanceNorm이 도움될꺼 같은 생각이든다.
- 입력: (B, C, H, W), $\mu, \sigma$: (B, C, 1, 1)
(B, C, HxW) 로 만들어놓고 배치, 채널 단위로 평/분을 때려버린다. 이미지마다 가지는 고유한 특징들이 완화된다.
1.3 GroupNorm
nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) # FLUX-VAE GroupNorm
FLUX-VAE, 그리고 SDXL VAE+Unet 에서 쓰는 norm.
QWEN에서 안쓰는거 보면 뭔가 꼬롬하다.
얘는 InstanceNorm의 문제를 완화하려고 나왔다.
- 너무 aggressive → style/contrast를 많이 없애버림 (HxW 한장을 통째로 Normalize 해버림)
- CNN task에서 성능 저하
channel을 k개의 그룹으로 나눠서 InstanceNorm을 한다.
InstanceNorm이 (B, C, HxW) 로해서 (B, 1, HxW) 를 C개 쌓는 방식으로 Norm을 한다고 헸다.
GroupNorm은 channel을 k개로 토막친 후 Norm한다.
(B, C1, H, W),(B, C2, H, W), …,(B, Ck, H, W)토막(B, 1, H1 x H2 ... x Hn x W1 x W2 ... x Wn)이걸 k개 쌓는다.
이렇게 하면 한채널이 뭉그러지는 강도가 현저히 줄게되어서 이미지 latent 채널 별 특징 뭉개짐이 줄어든다.
2. 요즘 핫한 Norm
요즘 핫한 Norm은 LayerNorm과 RMSNorm 이 분명하다. 이게 Transformer구조에서 상당히 유리한갑다.
2.1 LayerNorm
nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
얘는 그나마 다행인게 따로 튜닝없이 torch 자연빵 함수 그대로 쓴다.
표준이 시계열 데이터인거 같으니 이제부터 시계열 형태의 Norm으로 본다.
한 token의 feature 전체를 동시에 scale-invariant하게 만든다.
쉽게 그냥 (B, S, D) 에서 D의 평/분 구한후 norm 을 한다.
여기서 드는 의문은 원래 다르게 임베딩 된 애가 norm 거치고 나면 비슷하게 임베딩이 되는데 이게 괜찮은건가 하는 궁금증.
# 입력
[1, 2, 3, 4]
[4, 5, 6, 7]
# Norm 후
[-1.34, -0.45, 0.45, 1.34]
[-1.34, -0.45, 0.45, 1.34]
예를들어 이렇게 생긴 (1, 2, 4) 짜리 (B, S, D) 가 있을 때 LayerNorm 때리면 동일한 임베딩이 나온다.
ChatGPT가 말하길 스케일은 줄지만 방향은 안바뀌니 상관없다. “Transformer는 direction이 중요하기 때문에 스케일 알빠노” 라고 한다. Attention 할 때 softmax가 스케일에 매우 민감하기 때문에 완만한 distribution 만드는게 중요하다. 스케일은 버린다. 좀 이런식이다.
쪼금 꼬롬하긴 하지만 그런갑다 하고 넘어가야겠음ㅇㅇ
# QWEN-IMAGE Norm
class AdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# QWEN-Image 몸통
class QwenImageTransformer2DModel(nn.Module):
def __init__(
self,
params: QwenParams,
):
...
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
...
def forward(self, hidden_states, ..., timestep, ...):
...
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
return output
추가로 이건 QWEN-IMAGE 에 몸통 Transformer를 구성하는 AdaNorm클래스인데, 최종 output latent에 timestep embeding을 같이 넣어서 norm을 한 뒤 output project에 집어넣어 최종 결과를 뽑는다.
솔직히 왜 이런구조 했는지 모르겠음. temb에 왜 norm을 하는지 이유가 뭘까.
2.2 RMSNorm의 효능
솔직히 LayerNorm보다 얘가 더 중요할듯하다.
# FLUX RMSNorm
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
# return (x * rrms).to(dtype=x_dtype) * self.scale
return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
# QWEN-IMAGE RMSNorm
class RMSNorm(nn.Module):
r"""
RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
Args:
dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
eps (`float`): Small value to use when calculating the reciprocal of the square-root.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
bias (`bool`, defaults to False): If also training the `bias` param.
"""
def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
self.weight = None
self.bias = None
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
if bias:
self.bias = nn.Parameter(torch.zeros(dim))
def forward(self, hidden_states):
# if is_torch_npu_available():
# import torch_npu
# if self.weight is not None:
# # convert into half-precision if necessary
# if self.weight.dtype in [torch.float16, torch.bfloat16]:
# hidden_states = hidden_states.to(self.weight.dtype)
# hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
# if self.bias is not None:
# hidden_states = hidden_states + self.bias
# else:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
elif self.weight.dtype == torch.float8_e4m3fn: # fp8 support
hidden_states = hidden_states * self.weight.to(hidden_states.dtype)
hidden_states = hidden_states + (self.bias.to(hidden_states.dtype) if self.bias is not None else 0)
hidden_states = hidden_states.to(input_dtype)
return hidden_states
hidden_states = hidden_states * self.weight
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
정말 흉악하게 생겼다. nn.RMSNorm도 분명히 존재하는데 굳이 왜이렇게 해놨을까. 해답은 주석의 논문에 있겠으나 저건좀… 나중에 본다.
RMS (Root Mean Square)는 제곱평균제곱근으로 요란한 이름만큼 계산도 요란하다.
- 값들을 제곱
- 제곱한 값들의 평균을 계산
- 평균에 제곱근을 취함
우선 RMS 라는 것만 보면 아래의 두가지 장점이 있음.
- 계산량이 적다.
- 방향은 안건들고, 크기만 본다.
흔히 표준편차 계산을 보면
- 평균 계산
- $(x-\mu)^2$ 계산
- 다 더하고
- 제곱근
총 4단계 + 평균 계산 필요.
RMS는
sqrt(mean(x**2))
으로 매우 간단하다.
두번 째, LayerNorm에 비해 방향이 안변한다. LayerNorm은 평균을 각 벡터에 빼주기 때문에 벡터 방향이 아주 살짝 변한다. Transformer에서는 토큰 벡터의 방향이 매우 중요한데 이게 미세하게 깨지는 현상이 발생. 하지만 RMSNorm은 길이만 정규화하기 때문에 방향이 100% 유지되며 이는 Attention 연산에서 매우 중요한 효능으로 발휘된다.
최종 요약하면
- RMS(x) = sqrt(mean(x²)) 는 벡터의 전반적인 크기(에너지)를 빠르고 안정적으로 나타내는 통계량이다.
- 평균을 빼지 않아서 벡터 방향을 바꾸지 않으며, Transformer에서 표현력이 유지된다.
- 계산이 간단하고 수학적으로 안정해서 딥 Transformer/LLM에서는 LayerNorm보다 성능이 더 좋다.
근데 왜 nn.RMSNorm 안씀?
MSNorm은 간단해 보이지만, “어떻게 eps 쓰는지, 어떤 축을 normalize하는지, float cast를 언제할지”에 따라 결과가 달라짐.
우선 pytorch RMSNorm은 2023년에 들어온 비교적 신참이라 실전 검증이 덜되었다고 한다. 이게 문제가 뭐냐하면 학습때의 precision과 추론 때의 precision (ONNX/fp8) 차이는 상당히 큰데, 검증되지 않은 torch.nn.RMSNorm은 쓰기엔 좀 무리라는거다.
즉, FP16/ BF16/ FP32 cast 경로를 명확히 통제 하면서 여러가지 호환을 위해 직접 구현을한다고 하며 eps 위치에 따라 결과가 또 많이 달라진다고 한다.
FLUX만 봐도 실제 계산은 (x**2).mean()**0.5 이거와 진배 없지만 dtype 막 건들면서 각기를 친다.
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
# return (x * rrms).to(dtype=x_dtype) * self.scale
return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
QWEN의 경우 조금더 깊게 고민을 하는데
- dim을 tuple로 받을 수도 있게 설계
- fp8/half 지원
- weight/bias 옵션
- 다양한 디바이스(cpu/rocm/npu)에서 kernel fallback 지원
elementwise_affine은 temb 받아서 modulation 연산을 위한 용도다. 이거 할 때도 그냥 하지 않고 dtype 크게 바꾼후 해준다.
요약하면 얘네도 만들고 싶어서 만든게 아니고, 아직 torch.nn.RMSNorm이 불안정하기에 (추론/학습 환경) 어쩔 수 없이 내부에서 dtype, deivce를 고려한 설계를 해주고 있는것.
3. Modulation 연산
GAN 때부터 있던 외부 condition을 내부로 주입하는 용도로 쓰였던 Norm 기법과 비슷한디, Diffusion Transformer 구조에서 외부 컨디션 (timestep, guidance 등) sinsoidal 임베딩 후 scale, bias 해준다.
SDXL 까지만 해도 외부 임베딩을 그냥 덧셈만 해주었는데 SD3 구조에서 mod라는 이름으로 처음 봤던 구조다.
정식 명칭은 Feature-wise Linear Modulation 로 서로다른 정보를 결합하기 위한 변환 연산이다.
FLUX에서 modulation
우선 임베딩 벡터를만든다.
- timestep
- guidance scale (guidance distill을 해서 얘가 추가됨)
- clip_l head output (text 내용 통합 latent)
이거 3개로 vec라는 이름의 임베딩 latent를 만들고 nn.Linear를 이용해 아래의 구조를 거쳐서 shift, scale, gate 구조를 만들어준다.
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
그리고 만들어 둔 vec로 shift, bias 이런 연산들을 해줌.
def forward(~~) :
img_mod1, img_mod2 = self.img_mod(vec)
...
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
...
# 각종 Layer, Attention 연산 거친 후
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
여기서 의문접은 2가지.
- 왜
shift * x + bias가 아니고(1+shift) * x + bias임? ->shift가 0이면 변하면 안되는데 그냥 곱하면 0된다. gate는 뭐하는놈? -> LSTM처럼 residual 기여를 조절해주는 장치로x = x + f(x)대신 안정성을 위해x = x + gate * f(x)를 써준다.
QWEN에서의 Modulation
def __init__(~~) :
self.img_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
)
...
def _modulate(self, x, mod_params):
"""Apply modulation to input tensor"""
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(~~) :
img_mod_params = self.img_mod(temb) # [B, 6*dim]
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
...
hidden_states = hidden_states + img_gate1 * img_attn_output
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
del img_normed2, img_mod2
img_mlp_output = self.img_mlp(img_modulated2)
del img_modulated2
hidden_states = hidden_states + img_gate2 * img_mlp_output
구조가 상당히 변태같은데 어텐션 집어넣기 전에 mod + gate 해주고, 나오고나서 mod + gate 해주고난 다음에 걔를 또 gate해서 hidden_state에 더해준다.
정말 조심스럽게 연산하고 더해주는 방식인데, 게다가 모델도 존나게 크다. 이거를 어떻게 학습했나 모르겠다…