# CIFAR-10 원본은 32x32 사이즈지만, 우리가 만든 VLM 구조(96x96)와
# 완벽하게 호환시키기 위해 Resize(96)을 적용.
# 학습 데이터에는 데이터 증강(Augmentation)을 추가하여 일반화 성능 향상
train_transform = transforms.Compose([
# CIFAR-10(32x32)을 96x96으로 업스케일: LANCZOS 보간으로 경계(edge) 흐림 최소화
transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.LANCZOS),
transforms.RandomCrop(img_size, padding=12), # 패딩 후 랜덤 크롭: 위치 변화에 강건
transforms.RandomHorizontalFlip(p=0.5), # 좌우 반전
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # 색상 증강
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
transforms.RandomErasing(p=0.2, scale=(0.02, 0.2)), # 랜덤 패치 제거: 특정 영역 과의존 방지
])
# 테스트 데이터는 증강 없이 원본 그대로 사용 (동일한 LANCZOS 보간 적용)
test_transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.LANCZOS),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
print("데이터셋을 다운로드하고 준비합니다...")
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 데이터셋 샘플 몇개 저장
dataiter = iter(trainloader)
images, labels = next(dataiter)
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i in range(5):
img = images[i].numpy()
img = img * 0.5 + 0.5
img = np.transpose(img, (1, 2, 0))
axes[i].imshow(img)
axes[i].set_title(f"Label: {classes[labels[i]]}")
axes[i].axis('off')
ouptut_fname = "./cifar10_samples.png"
plt.savefig(ouptut_fname)
plt.show()