기존 트랜스포머에 Fast-dLLM의 개념을 적용하기 전, 두 모델의 근본적인 차이를 이해해야 한다.
오토-리그레시브 (AR) 트랜스포머: 이전 타임스텝의 출력이 다음 타임스텝의 입력이 되는 순차적 방식이다. 문장을 생성할 때, 이전에 생성된 단어들을 바탕으로 다음 단어 하나를 예측한다. 디코더의 Self-Attention에 Causal Mask를 적용하여 미래의 토큰을 참조하지 못하도록 막는 것이 핵심이다.
비-오토 리그레시브 (NAR) / Diffusion LLM: 문장의 여러 토큰, 혹은 전체 토큰을 동시에 생성하는 방식이다4444. Fast-dLLM이 기반하는 Masked Diffusion Model (MDM)은 [MASK] 토큰으로 가득 찬 시퀀스에서 시작하여, 여러 번의 정제(refinement) 단계를 거쳐 전체 문장을 완성한다5. 이 과정에서 디코더는 문장 전체의 맥락을 파악해야 하므로 양방향(Bidirectional) 어텐션을 사용한다.
이처럼 두 모델은 디코더의 동작 방식이 근본적으로 다르므로, 논문의 아키텍처를 그대로 이식하는 대신 학습 및 추론 과정을 시뮬레이션하는 방식으로 접근해야 한다.
Diffusion의 학습법: Masked Language Modeling
Diffusion 모델의 학습 목표는 노이즈가 낀 데이터에서 원본 데이터를 복원하는, 이른바 "Denoising" 과정이다. 텍스트 분야에서는 이 노이즈를 [MASK] 토큰으로 대체하여 구현한다. 이는 BERT의 MASK 방식 학습 아이디어와 유사해 보인다.
기존 AR 트랜스포머가 이 Denoising 능력을 학습하도록 데이터셋을 수정해야 한다. 타겟 문장의 일부를 랜덤하게 [MASK] 토큰으로 교체하고, 모델이 이 마스킹된 문장을 입력받아 원본 문장 전체를 예측하도록 학습 목표를 설정하는 것이다. 주요 구현을 의사코드로 확인해 보겠다.
class MaskedSeq2SeqDataset(Dataset):
def __init__(self, pairs: List[Tuple[str, str]], src_tok: Tokenizer, tgt_tok: Tokenizer, max_len: int = 40):
self.pairs = pairs
self.src_tok, self.tgt_tok = src_tok, tgt_tok
self.max_len = max_len
def __getitem__(self, idx):
src_txt, tgt_txt = self.pairs[idx]
src_ids = self.src_tok.encode(src_txt, add_eos=True)[:self.max_len]
tgt_ids = self.tgt_tok.encode(tgt_txt, add_sos=True, add_eos=True)[:self.max_len]
# 코사인 스케줄에 따라 마스킹할 개수 결정 (개념적인 시간 t에 해당)
t_rand = random.random()
num_to_mask = math.ceil(len(tgt_ids) * math.cos(t_rand * math.pi / 2))
masked_tgt_ids = list(tgt_ids)
maskable_indices = [i for i, t_id in enumerate(tgt_ids) if t_id not in (self.tgt_tok.sos_id, self.tgt_tok.eos_id)]
if len(maskable_indices) > 0 and num_to_mask > 0:
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
for i in indices_to_mask:
masked_tgt_ids[i] = self.tgt_tok.mask_id
return (torch.tensor(src_ids, dtype=torch.long),
torch.tensor(masked_tgt_ids, dtype=torch.long),
torch.tensor(tgt_ids, dtype=torch.long))
이에 대한 해결책으로 Confidence-Aware Parallel Decoding 전략을 제안한다. 이 전략은 모델이 예측한 확률 값, 즉 '자신감(confidence)'이 특정 임계값(threshold)을 넘는 토큰들만 선택적으로 예측하고, 나머지는 다음 스텝에서 다시 예측하도록 남겨두는 방식이다.
또한 추론 과정을 여러 블록(block)으로 나누어 점진적으로 생성하는 Block-wise Generation 방식을 채택했다.
@torch.no_grad()
def fast_dllm_decode(
model: NARTransformer,
src_tensor: torch.Tensor,
src_text: str,
tgt_tok: Tokenizer,
max_len: int,
num_blocks: int = 2,
steps_per_block: int = 12,
confidence_threshold: float = 0.9,
visualize: bool = False,
):
# ... (초기화 및 시각화 헬퍼 함수 생략) ...
# 블록 단위로 외부 루프를 순회한다.
for k in range(num_blocks):
start_idx = 1 + k * block_size
end_idx = min(start_idx + block_size, max_len)
# 아래 루프가 논문의 개념적인 '시간 t'의 흐름을 나타낸다.
# 블록 내에서 여러 스텝에 걸쳐 점진적으로 MASK를 채운다.
for t in range(steps_per_block):
step_count += 1
# 1. 모델을 통해 모든 위치의 토큰 확률을 예측한다.
logits = model.decode_step(ys, memory, tgt_pad_mask)
probs = F.softmax(logits, dim=-1)
confidences, predictions = probs.max(dim=-1)
# 현재 블록 내의 MASK 위치만 unmask 후보로 고려한다.
mask_positions = (ys == tgt_tok.mask_id) & current_block_mask
if not mask_positions.any(): break
# 2. Confidence가 임계값을 넘는 위치만 선택한다.
unmask_candidates = (confidences > confidence_threshold) & mask_positions
# 3. 만약 임계값을 넘는 토큰이 없다면, 가장 자신있는 토큰 하나만 선택하여 진행을 보장.
if not unmask_candidates.any():
masked_confidences = confidences.where(mask_positions, torch.tensor(-1.0, device=device))
if masked_confidences.max() > -1:
highest_idx = masked_confidences.argmax(dim=1, keepdim=True)
unmask_candidates.scatter_(1, highest_idx, 1)
if visualize:
_visual_print(...)
# 4. 선택된 위치의 MASK 토큰을 예측된 토큰으로 교체한다.
ys.masked_scatter_(unmask_candidates, predictions[unmask_candidates])
# ... (최종 결과 반환) ...
참고로 학습은 간단한 영-한 문장쌍 데이터셋을 간략히 구축해 진행하였고, 메커니즘만 확인할 목적으로 최소한의 GPU 리소스만 사용할 수 있도록 배치크기, 레이어 깊이 및 구조는 간략화된 버전으로 진행되었다.
이 글에서는 표준 AR 트랜스포머 아키텍처를 수정 없이 활용하면서, 학습 데이터 파이프라인과 추론 로직을 변경하여 Fast-dLLM 논문의 핵심 아이디어를 구현하는 방법을 살펴보았다.
레퍼런스
- Fast-dLLM: Training-free Acceleration of Diffusion LLM by Enabling KV Cache and Parallel Decoding
- NVlabs/Fast-dLLM: Official implementation of "Fast-dLLM: Training-free Acceleration of Diffusion LLM by Enabling KV Cache and Parallel Decoding"
- msarmi9/korean-english-multitarget-ted-talks-task · Datasets at Hugging Face
부록: 모델 개발 중 발생 가능한 이슈 해결법
예를 들어, 모델은 약 50만 개의 학습 가능한 파라미터를 가지고 있는데, 이처럼 방대한 학습 능력(capacity)을 가진 모델에게 적은 데이터셋 패턴을 학습시키는 것은 마치 대학 교수에게 알파벳만 외우게 하는 것과 같다.
결과 특정 에포크 지점에 도달하면, 모델은 훈련 데이터에 대한 손실(loss)을 거의 0에 가깝게 최소화한다. 이 상태는 모델이 더 이상 배울 것이 없는 '포화 상태'이다. 이후에도 학습을 계속하면, 옵티마이저는 더 이상 의미 있는 방향으로 가중치를 갱신하지 못하고, 아주 작은 그래디언트 변화에 따라 기존의 최적점에서 미세하게 벗어났다가 돌아오는 과정을 반복한다. 이것이 바로 손실 값이 안정적으로 수렴하지 못하고 불규칙하게 진동(vibration)하는 현상으로 나타나는 것이다.
2. 검증 기반 제어 장치
검증 세트의 본질적인 역할은 훈련 데이터에 포함되지 않은 데이터를 통해 모델의 일반화 성능(Generalization Performance)을 측정하는 것이다. 이 과정이 없으면 모델이 훈련 데이터에 얼마나 과적합되고 있는지 객관적으로 파악할 수 없다. 또한, 과적합이 시작되는 시점에 훈련을 자동으로 중단시키는 표준적인 기법인 조기 종료(Early Stopping)를 구현할 수 없다. 과적합이 발생하여 더 이상의 학습이 무의미해진 이후에도 모델은 불필요한 훈련을 계속 진행한다.
3. 부족한 정규화(Regularization) 기법
과적합을 억제하기 위한 장치로 드롭아웃(Dropout)이 적용되어 있기는 하다. 드롭아웃은 훈련 중에 무작위로 뉴런을 비활성화하여 모델이 특정 뉴런에 과도하게 의존하는 것을 막는 효과적인 기법이다. 모델의 가중치가 너무 커지는 것을 방지하여 과적합을 억제하는 가중치 감쇠(Weight Decay)와 같은 다른 보편적인 정규화 기법이 부재할 수 있다. 부족한 정규화는 모델이 제한된 훈련 데이터의 패턴을 더 빠르고 쉽게 암기하도록 만들어 과적합을 가속화하는 요인으로 작용한다.
이러한 문제들을 해결하고 모델을 안정적으로 훈련시키기 위한 방안은 다음과 같다.
1. 데이터셋 교체 및 증강
실제 대용량 데이터셋을 사용해야 한다. 예를 들어, 허깅페이스에 공개된 데이터셋은 수만 개 이상의 문장 쌍으로 구성되어 있어, 모델이 일반화된 언어 패턴을 학습하는 데 필수적이다.
2. 검증 루프 및 조기 종료 구현
과적합을 방지하고 훈련 효율성을 높이기 위해 검증 및 조기 종료 로직을 도입 한다. 이는 가장 표준적이고 효과적인 방법이다.
# 조기 종료 로직
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience_counter = 0
# 여기서 최고 성능 모델의 가중치를 저장하는 것이 좋음
else:
patience_counter += 1
if patience_counter >= patience:
print("Early stopping due to no improvement in validation loss.")
break
3. 학습률 스케줄러 및 가중치 감쇠 추가
Learning Rate Scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau와 같은 스케줄러를 추가하면, 검증 손실이 정체될 때 학습률(learning rate)을 동적으로 낮추어 모델이 최적점에 더 안정적으로 수렴하도록 도울 수 있다.
Weight Decay: 옵티마이저를 생성할 때 weight_decay 파라미터를 추가하여 L2 정규화를 적용한다. 이는 모델의 가중치가 너무 커지는 것을 방지하는 역할을 한다.
# 옵티마이저 생성 시 weight_decay 추가
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)