이 글은 ViT 모델을 스크래치하는 방법을 정리한다.
VLM 개념도
현재 시점(2026.3)에서 최상위 오픈소스 소형 VLM 중 가성비 제일 좋은 것은 gemma-4, qwen vlm 모델이다. 단 gemma-4 는 아무리 작은 모델도 파인튜닝을 위해선는 32GB 이상 VRAM이 필요하다.
개인 로컬 GPU 에서는 SmolVLM, nanoVLM, 팔마 등이 그나마 적절한 VRAM(16GB) 에서 동작한다.
여기서는 VLM의 기반이 되는 ViT모델을 프롬 스크래치해본다. 가장 일반적인 예제인 CIFAR10 이미지 데이터셋을 이용해 VIT를 학습해 보기로 한다. 이 예시는 적은 VRAM(8GB)이내에서 가볍게 학습 방법을 알아보기 위한 목적으로 프롬 스크래치한다. 각 단계별로 주요 부분을 다음과 같이 코딩한다.
모델 하이퍼파라메터 설정
import os, torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
# 1. 하이퍼파라미터 설정
batch_size = 128 # 범위(32~128),
learning_rate = 1e-4 # Adam 추천값(0.001~0.0001)
epochs = 50 # 범위(10~50)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 모델 아키텍처 사이즈
img_size = 96
patch_size = 16
n_embd = 256 # 표현력 (n_head=8로 나눠떨어짐)
n_head = 8
num_blks = 6
emb_dropout = blk_dropout = 0.2
num_classes = 10 # CIFAR-10이므로 10개 클래스
데이터셋 로딩
# CIFAR-10 원본은 32x32 사이즈지만, 우리가 만든 VLM 구조(96x96)와
# 완벽하게 호환시키기 위해 Resize(96)을 적용.
# 학습 데이터에는 데이터 증강(Augmentation)을 추가하여 일반화 성능 향상
train_transform = transforms.Compose([
# CIFAR-10(32x32)을 96x96으로 업스케일: LANCZOS 보간으로 경계(edge) 흐림 최소화
transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.LANCZOS),
transforms.RandomCrop(img_size, padding=12), # 패딩 후 랜덤 크롭: 위치 변화에 강건
transforms.RandomHorizontalFlip(p=0.5), # 좌우 반전
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # 색상 증강
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
transforms.RandomErasing(p=0.2, scale=(0.02, 0.2)), # 랜덤 패치 제거: 특정 영역 과의존 방지
])
# 테스트 데이터는 증강 없이 원본 그대로 사용 (동일한 LANCZOS 보간 적용)
test_transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.LANCZOS),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
print("데이터셋을 다운로드하고 준비합니다...")
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 데이터셋 샘플 몇개 저장
dataiter = iter(trainloader)
images, labels = next(dataiter)
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i in range(5):
img = images[i].numpy()
img = img * 0.5 + 0.5
img = np.transpose(img, (1, 2, 0))
axes[i].imshow(img)
axes[i].set_title(f"Label: {classes[labels[i]]}")
axes[i].axis('off')
ouptut_fname = "./cifar10_samples.png"
plt.savefig(ouptut_fname)
plt.show()
로딩된 데이터셋 일부 예시는 다음과 같다.
# 이미지 패치를 임베딩
class PatchEmbeddings(nn.Module):
def __init__(self, img_size, patch_size, hidden_dim):
super().__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, X):
X = self.conv(X) # shape = (B, hidden_dim, H/patch_size, W/patch_size). (128, 256, 6, 6)
X = X.flatten(2) # flatten(2) means flattening the last two dimensions. shape = (128, 256, 36)
X = X.transpose(1, 2) # shape = (128, 36, 256)
return X
patch_emb = PatchEmbeddings(img_size, patch_size, n_embd).to(device)
print("입력 이미지 크기:", images.shape)
print("패치 임베딩 출력 크기:", patch_emb.forward(images.to(device)).shape)
ViT는 트랜스포머 구조를 그대로 활용한다. 이를 아래 ViT와 연계해 사용한다.
class ViT(nn.Module): # 비전 트랜스포머 모델
def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout, num_classes):
super().__init__()
self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens) # 이미지 패치를 임베딩 벡터로 변환
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens)) # 클래스 토큰 초기화 (학습 가능한 파라미터)
num_patches = (img_size // patch_size) ** 2
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens)) # 위치 임베딩 초기화 (학습 가능한 파라미터)
self.dropout = nn.Dropout(emb_dropout)
self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)]) # 트랜스포머 블록 여러 개 쌓기
self.layer_norm = nn.LayerNorm(num_hiddens) # 최종 레이어 정규화
self.classifier = nn.Linear(num_hiddens, num_classes) # 분류기(클래스 토큰을 최종적으로 클래스 확률로 변환)
def forward(self, X):
x = self.patch_embedding(X) # shape = (B, num_patches, num_hiddens) from patch embedding. (128, 3, 96, 96) -> (128, 36, 256)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # shape = (B, 1, 256)
x = torch.cat((cls_tokens, x), dim=1) # shape = (B, num_patches + 1, num_hiddens). (128, 37, 256)
x += self.pos_embedding # 위치 임베딩 추가. 초기값은 랜덤이지만 학습을 통해 최적화됨. (128, 37, 256)
x = self.dropout(x)
for block in self.blocks:
x = block(x)
cls_output = self.layer_norm(x[:, 0]) # 클래스 토큰의 최종 표현을 추출 (shape = (B, num_hiddens). (128, 256))
logits = self.classifier(cls_output) # 클래스 토큰을 최종적으로 클래스 확률로 변환 (shape = (B, num_classes). (128, 10))
return logits
모델 학습
print(f"1. 모델 초기화 (디바이스: {device})")
model = ViT(img_size, patch_size, n_embd, n_head, num_blks, emb_dropout, blk_dropout, num_classes).to(device)
# Adam에 weight decay 추가 (일반화 성능 향상)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
# [개선] CosineAnnealingLR: 코사인 곡선으로 학습률 감소 → ReduceLROnPlateau 대비 안정적 수렴
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
print("2. CIFAR-10 학습 시작")
best_val_loss = float('inf')
patience = 10 # [개선] 3→10: 코사인 스케줄러 환경에서 일시적 val loss 증가 허용
patience_counter = 0
for epoch in range(epochs):
model.train()
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
# 100 미니배치마다 로그 출력
if i % 100 == 99:
print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 100:.3f}")
running_loss = 0.0
# validation loss 계산 (테스트 데이터 사용)
model.eval()
val_loss = 0.0
with torch.no_grad():
for val_inputs, val_labels in testloader:
val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
val_outputs = model(val_inputs)
val_loss += criterion(val_outputs, val_labels).item()
val_loss /= len(testloader)
print(f"[Epoch {epoch + 1}] Validation Loss: {val_loss:.4f}")
# CosineAnnealingLR 스케줄러 업데이트 (매 epoch 호출)
scheduler.step()
# early stopping
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# 모델 저장 (optional)
torch.save(model.state_dict(), 'best_vit_model.pth')
else:
patience_counter += 1
if patience_counter >= patience:
print("Early stopping triggered.")
break
학습 결과
학습된 결과는 다음과 같다.
1. 모델 초기화 (디바이스: cuda)
2. CIFAR-10 학습 시작
[Epoch 1, Batch 100] Loss: 2.161
[Epoch 1, Batch 200] Loss: 2.026
[Epoch 1, Batch 300] Loss: 1.967
[Epoch 1] Validation Loss: 1.8028
[Epoch 2, Batch 100] Loss: 1.856
[Epoch 2, Batch 200] Loss: 1.797
[Epoch 2, Batch 300] Loss: 1.745
VLM 파인튜닝 레퍼런스
- Fine-Tuning Gemma 3 VLM using QLoRA for LaTeX-OCR Dataset
- Vision-language-models-VLM: vision language models finetuning notebooks (Medgemma - paligemma - florence .....)
- How to Fine-Tune Qwen3-VL on Your Own Dataset | Datature Blog (over 32g vram)
- VLM-LORA finetuning using OpenCLIP Workload — AMD Enterprise AI for robotics
- The Definitive Guide to Fine-Tuning a Vision-Language Model on a Single GPU (with code) with DORA | by Pavan Kunchala | Medium
- LoRA in Vision Language Models: Efficient Fine-tuning with LLaVA | by Phrugsa Limbunlom (Gift) | Artificial Intelligence in Plain English
- nanoVLM: The simplest repository to train your VLM in pure PyTorch
- SmolVLM - small yet mighty Vision Language Model
VLM 스크래치 레퍼런스
- Training a Vision Language Model from scratch (VLM multi-modal) | by Saptarshi MT | Medium
- Implementation of Vision language models (VLM) from scratch: A Technical Deep Dive. | by Achraf Abbaoui | Medium
- Wiring the Multimodal Mind: Building a Vision Language Model (VLM) from Scratch - Part 1 | by Priyanthan Govindaraj | Medium
- Vidit-Ostwal/VLM-from-scratch: This is majorly for my own learning purpose.
- Building a Nano Vision-Language Model from Scratch
- nipunbatra/vlm-from-scratch
- Building PaliGemma VLM From Scratch using Pytorch | by Shanmuka Sadhu | Jan, 2026 | Medium
오픈소스 라이브러리
ViT 개념 설명 레퍼런스
- Vision Transformer (ViT) from Scratch
- Vision Language Model from scratch in Pytorch #vlm - Qiita
- ViT Scratch Implementation - PyTorch
- Building Vision Transformers (ViT) from Scratch | by Maninder Singh | Medium
- 今井美樹 彼女と TIP ON DUO 歌詞 - 歌ネット
- Building a Vision Transformer from Scratch in PyTorch - GeeksforGeeks
- Training a Vision Transformer from Scratch on CIFAR-10:No Pre-training, No Problem | by Akshay Gokhale | Medium
- Vision Transformer For CIFAR-10
댓글 없음:
댓글 쓰기