2024년 2월 18일 일요일

파이토치로 멀티모달 생성AI 모델, Stable Diffusion 아키텍처를 코딩, 구현해보기

이 글은 Text-To-Image 생성AI 기술을 대중화시킨 Stable Diffusion 아키텍처를 파이토치로 직접 개발해본다. 이를 통해, 앞서 분석한 스테이블 디퓨전 아키텍처를 어떻게 동작하는 코드로 맵핑했는 지를 확인한다. 

스테이블 디퓨전 모델은 기존 자연어 처리 분야에서 개발된 트랜스포머, 컴퓨터 비전 딥러닝 모델 기술을 적극 사용한다. 이와 관련된 내용을 깊게 이해하고 싶다면 다음을 참고한다. 
이 글은 스테이블 디퓨전의 핵심 모델인 트랜스포머, CLIP 등의 이해를 필수로 한다. 이는 다음 글을 참고한다. 
이 글은 멀티모달, 디퓨전과 관련된 다양한 문헌을 참고해 정리된 것이다. 관련 내용은 이 글의 레퍼런스를 참고한다. 참고로, 이 글의 소스 파일은 github에 공개된 Binxu의 코드(참고. Kempner Institute, Harvard University)와 Fareed Khan 코드(2024)를 참고한 것이다. 

이 글 소스 파일은 다음 github 링크를 통해 다운로드 가능하다. 참고로, 이 코드는 GPU 드라이버가 설치되어 있는 컴퓨터, PyTorch가 설치된 conda 개발 환경에서 실행될 수 있다. 
스테이블 디퓨전 아키텍처 모델의 깊은 이해가 아닌, 프로그램 설치 사용 만을 원한다면 아래 링크를 참고한다.
스테이블 디퓨전 아키텍처 구성요소
앞의 글에서 설명했듯이, 스테이블 디퓨전은 다음 그림과 같이 디퓨전 모델, U-Net, 오토인코더(Autoencoder), 트랜스포머(Transformer) 어텐션 모델(Attention)을 사용해, 학습한 모델을 역으로 계산해 주어진 텍스트 조건에서 이미지가 생성될 수 있도록 한다.
스테이블 디퓨전 아키텍처 기반 Text To Image 생성(Inference) 및 학습 과정

이 글에서 구현할 아키텍처 구성요소를 나열해 본다. 
  • VAE 오토인코더: 입출력 데이터를 잠재공간차원에 맵핑. 계산 성능을 높임
  • 순방향 확산(forward diffusion): 입력 이미지에서 점진적으로 노이즈 이미지로 계산. 학습용 데이터로 사용.
  • 역방향 확산(Reverse diffusion): 노이즈에서 이미지를 생성
  • U-Net: 노이즈 예측 학습용
  • 컨디셔닝: 텍스트에 따른 조건부 이미지 생성용. CLIP모델 같은 트랜스포머 어텐션 사용

앞의 관련글에서 설명하였듯이, 스테이블 디퓨전 데이터 학습은 A100 x 50에서도 최소 몇 일은 걸리는 작업이다. 참고로, 독일 뮌헨 대학에서 학습한 규모의 이미지 데이터량을 학습하려면, 개인이 하기에는 비싼 비용과 시간이 필요하므로, 이 글에서는 MNIST와 같은 소형 데이터셋을 대상으로 학습하여, 학습되는 파라메터 수를 GPU 2GB 내에서 계산될 수 있도록 한다. 

이 글에서 테스트된 스테이블 디퓨전 학습 시 필요한 GPU RAM 표시 

아키텍처와 관련된 상세한 설명은 다음 링크를 참고한다. 

학습 데이터 준비
MNIST 데이터를 준비한다. 이를 위해, torchvision 라이브러리를 사용해 데이터를 다운로드하고, 확인해 본다. 

import torch, torchvision, matplotlib.pyplot as plt
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

unique_images, unique_labels = next(iter(train_loader))
unique_images = unique_images.numpy()

row, column = 4, 16
fig, axes = plt.subplots(row, column, figsize=(16, 4), sharex=True, sharey=True)  

for i in range(row):  
    for j in range(column):  
        index = i * column + j  
        axes[i, j].imshow(unique_images[index].squeeze(), cmap='gray') 
plt.show()


디퓨전 모델 학습 방법 설계
스테이블 디퓨전의 학습 목표는 역확산을 통해 잡음을 제거하여 원하는 데이터를 생성하는 방법을 배우는 것이다. 그러므로, 잡음 제거를 목표로 하여 샘플의 잡음을 제거하도록 신경망을 훈련한다. 참고로, 여기서는 앞서 설명한 전확산, 역확산 과정을 입력 데이터 형식과 학습 성능을 고려해, 수정된 확산 모델을 사용하기로 한다. 

이를 위해, 다음 노이즈 제거 목적 방정식을 정의한다.
여기서, p0(x0)는 목표의 분포(예. 고양이 이미지), x(noised)는 순방향 확산 후 목표 분포 x0의 샘플을 의미한다. 즉, [x(노이즈) - x0)는 정규 분포 확률 변수가 된다. 

이를 구현 가능한 방식으로 표현하면 다음과 같다. 
노이즈 제거 목적 방정식

이기서, J는 노이즈 제거 목표, E는 기대치, t는 시간 매개변수, x0는 목표 분포 p0(x0)의 샘플, x(noise)는 한 단계 순확산 후 목표 분포 샘플 x0, s()는 score 함수, σ(t)는 시간 함수, ε는 정규확률분포변수이다. 

학습 목표는 확산 과정의 모든 시간 t와 원래 분포(예. 고양이, 가족 이미지 등)의 모든 x0에 대한 샘플에 추가되는 노이즈 양을 효과적으로 예측하는 것이다. 

score 함수를 손실함수로 사용한다는 의미는 무작위 노이즈를 의미 있는 데이터 형태로 변환하는 과정을 학습한다는 뜻이 된다. 이를 위해, 신경망을 이용해 score 함수를 근사화한다. 

확산 모델 score 함수를 고려한 손실함수 구현은 순서 상 이 글 마지막에 다룬다.

시간 임베딩
score 함수는 시간에 따라 정확히 신경망이 반응하도록 구현해야 한다. 이를 위해, 시간 임베딩을 사용해 학습시킨다. 

시간 임베딩은 트랜스포머의 위치 임베딩과 유사한 정현파를 사용해 시간 특징을 계산할 수 있다. 다양한 시간 표현을 학습 시 입력함으로써 시간 변화에 따라 확산 과정을 학습시키도록 한다. 이를 통해, 시간종속적인 s(x, t)를 손실함수의 일부로 학습시킨다. 이를 위해, 다음 두 개 클래스를 만든다. 

1. 가우스 랜덤 기능 모듈 
이 모듈은 학습 컨텐스트에서 시간 단계를 표현하는 데 사용된다. 이를 이용하면, 각 시간단계 전반에 걸쳐 임의의 주파수가 생성된다. 이를 위해, 시간단계 별 sine, cosine 투영하여 시간 패턴 특징을 계산할 수 있다. 이 결과는 신경망 학습에 입력된다. 

class GaussianFourierProjection(nn.Module):  # 시간 특징 계산 클래스
    def __init__(self, embed_dim, scale=30.): # 임베딩 차원, 랜덤 가중치(주파수)를 위한 스케일 변수
        super().__init__()

        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)  # 랜덤 샘플링. 훈련 파라메터 아님.

    def forward(self, x): # 매 시간단위 텐서 입력
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi  # Cosine(2 pi freq x), Sine(2 pi freq x)
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)  # 최종 sine, cosine 결과 벡터 결합

2. 특징 텐서 계산
입력을 4D 특징 텐서로 출력하는 모듈이다. 이 차원 재구성 작업은 다음 컨볼루션 계층 처리에 적합한 맵으로 변환하기 위함이다. 
class Dense(nn.Module): # 특징 계산 클래스
    def __init__(self, input_dim, output_dim):  # 입력 특징 차원, 출력 특징 차원
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim) # 선형 완전연결 레이어
    def forward(self, x): # 입력 텐서
        return self.dense(x)[..., None, None]  # 학습 후 4D텐서

U-Net 구현
이미지를 다룰 때, 이미지 특징을 캡쳐하기 위해, U-Net 네트워크를 사용한다. U-Net 은 다른 공간 스케일에 대한 이미지 특징을 학습하는 것에 초점을 둔다. 

U-Net는 시간에 따라 어떻게 데이터가 변경되는 지를 학습해야 한다. 이 모델은 이 패턴을 학습한다. U-Net의 인코딩 경로는 이미지 해상도 다운샘플링, 특징 캡쳐를 위한 컨볼루션 레이어인 h1, h2, h3, h4로 구성된다. 디코딩 경로는 transpose 컨볼루션 레이어이며, 텐서 h가 h4에서 h1 레이어를 통과하여, 업샘플링된다. 

class UNet(nn.Module): # U-Net 모델 정의. nn.Module 클래스 상속
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):  # marginal_prob_std: 시간 t에 대한 표준편차 반환 함수, channels: 각 해상도의 특징 맵의 채널 수, embed_dim: 가우시안 랜덤 특징 임베딩의 차원
        super().__init__()

        # 시간에 대한 가우시안 랜덤 특징 임베딩 계층
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        # 인코딩 레이어 (해상도 감소)
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        # 추가 인코딩 레이어 (원본 코드에서 복사)
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # 해상도 증가를 위한 디코딩 레이어 
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)

        # 스위치 활성화 함수
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t, y=None):    # U-Net 아키텍처를 통과한 출력 텐서 반환. x: 입력 텐서, t: 시간 텐서, y: 타겟 텐서 (이 전방 통과에서 사용되지 않음). h: U-Net 아키텍처를 통과한 출력 텐서
        # t에 대한 가우시안 랜덤 특징 임베딩 획득
        embed = self.act(self.time_embed(t))

        # 인코딩 경로
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))

        # 추가 인코딩 경로 레이어 (원본 코드에서 복사)
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))

        # 디코딩 경로
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3)
        h += self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2)
        h += self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)

        # 정규화 출력
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h

포워드 디퓨전 프로세스
전방향 확산 모델을 효과적으로 계산하기 위해, 다음 방정식을 고려한다. 

이 방정식은 변수 x에 대한 변화가 시간 t에 대한 노이즈가 dw에 에 비례하는 방식으로 동작한다. 노이즈 수준은 파라메터 σ에 결정되며, 지수적으로 증가한다. 

이 프로세스는 처음 x(0) 초기값이 주어지고, x(t)에 대한 분석적인 솔류션을 찾을 수 있다.  

이 모델에서, σ(t)는 표준 편차(marginal standard deviation)로 사용되며, x(t)는 분산의 변동으로 사용된다. σ(t)은 다음과 같이 계산된다. 

이 모델은 시간에 따라 어떻게 잡음 수준 σ가 정해지는 지 이해를 제공한다. 다음은 이를 코딩한 것이다. 
device = "cuda"

def marginal_prob_std(t, sigma): # 시간 t에 대한 표준편차 반환 함수
    t = torch.tensor(t, device=device) # 시간 텐서    
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma)) # 시간 t에 대한 표준편차 반환

def diffusion_coeff(t, sigma): # 확산 계수 함수. t: 시간 텐서, sigma: SDE의 시그마
    return torch.tensor(sigma**t, device=device) # 확산 계수 반환

sigma =  25.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

손실 함수 구현
U-Net 구현에 앞서, 언급되었던 score 함수를 학습할 수 있도록, loss 함수를 구현한다.

이 함수는 임의의 시간단위를 샘플링하고, 여기서 잡음 수준을 얻은 후, 이 잡음을 데이터와 더한다. 그리고, 실제 데이터와 에측 데이터간의 오차를 계산하여, 에러를 줄이를 방향으로 학습하도록 한다.

def loss_fn(model, x, marginal_prob_std, eps=1e-5): # 손실함수. score-based generative models 훈련용. model: 시간 의존 스코어 모델, x: 훈련 데이터 미니배치, marginal_prob_std: 표준편차 반환 함수, eps: 수치 안정성을 위한 허용값
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - 2 * eps) + eps # 미니배치 크기만큼 랜덤 시간 샘플링

    std = marginal_prob_std(random_t)  # 랜덤 시간에 대한 표준편차 계산
    z = torch.randn_like(x)                   # 미니배치 크기만큼 정규 분포 랜덤 노이즈 생성    
    perturbed_x = x + z * std[:, None, None, None] # 노이즈로 입력 데이터 왜곡

    score = model(perturbed_x, random_t) # 왜곡 데이터, 시간 입력, 계산된 스코어 획득    
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3))) # score 함수와 잡음에 기반한 손실값 계산
    
    return loss

샘플러 코딩
Stable Diffusion는 임의 시점에서 이미지를 생성한다. 노이즈 예측기(noise predictor)는 얼마나 이미지에 노이즈를 줄 것인지를 예측한다. 이 예측된 노이즈는 이미지로 부터 제거된다. 전체 사이클은 몇 번 반복되어 깨끗한 이미지를 생성하게 된다. 

이 clearning-up 프로세스를 '샘플링'으로 알려져 있고, 매 학습 단계마다 새로운 이미지를 생성되도록 한다. 이를 샘플러, 샘플링 방법이라 한다.

스테이블 디퓨전은 이미지 샘플링을 생성하기 위한 다양한 방법이 있다. 여기서는 Euler-Maruyama 방법론을 사용한다.

# number of steps
num_steps = 500

def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3, y=None): # Euler-Maruyama sampler 함수. score-based 모델을 사용해 샘플 생성. score_model: 시간 의존 스코어 모델, marginal_prob_std: 표준편차 반환 함수, diffusion_coeff: 확산 계수 함수, batch_size: 한번 호출시 생성할 샘플러 수, x_shape: 샘플 형태, num_steps: 샘플링 단계 수, device: 'cuda' 또는 'cpu', eps: 수치 안정성을 위한 허용값, y: 타겟 텐서 (이 함수에서 사용되지 않음) 

    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) * marginal_prob_std(t)[:, None, None, None]  # Initial sample
    
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1] # Step size 시리즈
    x = init_x # 시간 t에 대한 초기 샘플
    
    with torch.no_grad(): # Euler-Maruyama 샘플링
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    
    return mean_x

U-Net 학습 및 결과 시각화
앞서 구현된 U-Net 모델을 학습한다. 학습은 50 에폭, 2024 미니배치크기로 진행하며, 데이터셋은 MNIST를 사용한다. 학습된 모델은 chpt.pth 로 저장된다. 

batch_size = 2048
transform = transforms.Compose([
    transforms.ToTensor()  # Convert image to PyTorch tensor
    # transforms.Normalize((0.5,), (0.5,))  # Normalize images with mean and std
])
dataset = torchvision.datasets.MNIST('.', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 50
lr = 5e-4

optimizer = Adam(score_model.parameters(), lr=lr)

tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x, y in tqdm(data_loader):
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    torch.save(score_model.state_dict(), 'ckpt.pth')

학습된 결과를 다음과 같이 시각화해본다. 샘플러는 앞서 정의한 Euler_Maruyama_sampler이며, 이 함수는 앞서 U-Net으로 학습된 score_function을 호출해 그 결과를 리턴한다(샘플링 결과). 결과를 보면 알겠지만, 노이즈 입력에 대한 필기체가 출력된 것을 확인할 수 있다. 
sample_batch_size = 64
num_steps = 500

sampler = Euler_Maruyama_sampler  # Euler-Maruyama sampler 사용
samples = sampler(score_model, marginal_prob_std_fn,
                  diffusion_coeff_fn, sample_batch_size,
                  num_steps=num_steps, device=device, y=None)
samples = samples.clamp(0.0, 1.0)

import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.figure(figsize=(6, 6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()    

이제 텍스트(여기서는 숫자)를 입력하면, 이미지가 생성되도록 어텐션 레이어를 추가해 본다.

멀티모달 처리를 위한 어텐션 레이어와 트랜스포머 모듈 개발
트랜스포머 아키텍처의 어텐션 레이어는 입력에 대한 cross attention, spatial transformer (공간 트랜스포머)를 구현해야 한다.

어텐션 모델은 QKV(Query, Key, Value) 벡터로 학습된 컨텍스트 KV에 대한 질의 Q에 대한 스코어를 계산한다. 이런 특징은 멀티모달 생성AI를 가능하게 하는 핵심이 된다(예. KV는 이미지 학습, Q는 음성 입력). 이 벡터 QKV는 다음과 같이 차원 공간에서 기저벡터 e로 표현될 수 있다. 

핵심은 Q와 KV의 유사도 계산한 상호 벡터간 거리가 가까워지는 방향으로 학습 모델을 설계하는 것에 있다(상세한 설명은 글 앞의 트랜스포머 설명 링크를 참고). 여기서는 Q가 MNIST의 각 숫자 번호가 될 것이다. 참고로, 셀프 어텐션은 입력 토큰 내 관계를 학습하며, 교차 어텐션의 경우 입력 토큰과 컨텍스트 특징 간의 관계를 학습한다. 이 계산 결과로 컨텍스트 벡터를 리턴한다. 

어텐션은 torch.einsum 함수를 사용해(참고 - Einstein summation 함수), 다음과 같이 코딩될 수 있다.
class CrossAttention(nn.Module):
    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1): # 임베딩 차원, 은닉 차원, 컨텍스트 차원(self attention이면 None), 어텐션 헤드 수
        super(CrossAttention, self).__init__()

        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim

        self.query = nn.Linear(hidden_dim, embed_dim, bias=False)  # Query에 대한 학습을 위한 선형 레이어
        
        if context_dim is None:
            self.self_attn = True
            self.key = nn.Linear(hidden_dim, embed_dim, bias=False)
            self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)
        else:
            self.self_attn = False
            self.key = nn.Linear(context_dim, embed_dim, bias=False)
            self.value = nn.Linear(context_dim, hidden_dim, bias=False)

    def forward(self, tokens, context=None): # 토큰들[배치, 시퀀스 크기, 은닉 차원], 컨텍스트 정보[배치, 컨텍스트 시퀀스 크기, 컨텍스트 차원]. self_attn이 True면 컨텍스트는 무시됨
        if self.self_attn: # Self-attention case
            Q = self.query(tokens)
            K = self.key(tokens)
            V = self.value(tokens)
        else: # Cross-attention case
            Q = self.query(tokens)
            K = self.key(context)
            V = self.value(context)

        # Compute score matrices, attention matrices, and context vectors
        scoremats = torch.einsum("BTH,BSH->BTS", Q, K)  # Q, K간 내적 계산. 스코어 행렬 획득
        attnmats = F.softmax(scoremats / math.sqrt(self.embed_dim), dim=-1) # 스코어 행렬의 softmax 계산
        ctx_vecs = torch.einsum("BTS,BSH->BTH", attnmats, V)  # 어텐션 행렬 적용된 V벡터 계산 
        return ctx_vecs

앞의 어텐션 레이어를 포함한 트랜스포머 모듈을 개발한다. 

이 트랜스포머 모듈은 교차 어텐션, 피드포워드 신경망을 통합한다. 차원 형태는 [batch, sequence_len, hidden_dim] 입력 텐서, [batch, context_seq_len, context_dim]인 컨텍스트 텐서를 사용한다. 셀프-어텐션 및 크로스 어텐션 모듈 처리 다음에 레이어 정규화, 잔차 연결이 계산된다. 비선형 변환을 위한 GELU, MLP 레이어가 포함된다. 

class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, context_dim):  # 은닉 차원, 컨텍스트 차원
        super(TransformerBlock, self).__init__()

        self.attn_self = CrossAttention(hidden_dim, hidden_dim)
        self.attn_cross = CrossAttention(hidden_dim, hidden_dim, context_dim)

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 3 * hidden_dim),
            nn.GELU(),
            nn.Linear(3 * hidden_dim, hidden_dim)
        ) # Feed forward neural network. 2개의 레이어로 구성. 첫번째 레이어는 3 * hidden_dim개의 은닉 유닛을 가지고 nn.GELU 비선형성 함수를 사용. 두번째 레이어는 hidden_dim개의 은닉 유닛을 가짐

    def forward(self, x, context=None): # x: 입력 텐서[배치, 시퀀스 크기, 은닉 차원], context: 컨텍스트 텐서[배치, 컨텍스트 시퀀스 크기, 컨텍스트 차원]
        x = self.attn_self(self.norm1(x)) + x # self-attention 적용 후 layer normalization과 잔차 연결 적용
        x = self.attn_cross(self.norm2(x), context=context) + x # cross-attention 적용 후 layer normalization과 잔차 연결 적용
        x = self.ffn(self.norm3(x)) + x # feed forward neural network 적용 후 layer normalization과 잔차 연결 적용

        return x

이제, 공간 트랜스포머 모듈로 앞의 모듈을 통합한다. 

class SpatialTransformer(nn.Module):
    def __init__(self, hidden_dim, context_dim):
        super(SpatialTransformer, self).__init__()
        
        self.transformer = TransformerBlock(hidden_dim, context_dim)

    def forward(self, x, context=None): # x: 입력 텐서[배치, 채널, 높이, 너비], context: 컨텍스트 텐서[배치, 컨텍스트 시퀀스 크기, 컨텍스트 차원]
        b, c, h, w = x.shape
        x_in = x

        x = rearrange(x, "b c h w -> b (h w) c") # 입력 텐서의 차원을 재배열
        x = self.transformer(x, context) # 트랜스포머 블록 적용
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) # 텐서의 차원을 원래대로 복원

        return x + x_in # 공간 변환기의 출력과 입력의 잔차 연결

U-Net과 공간 트랜스포머 통합
앞서 설명된 U-Net과 공간 트랜스포머를 통합해, 특정 조건에 따라(예. 입력 텍스트, 숫자 등), 이미지가 생성되도록 U-Net 학습 모델을 일부 수정한다. 적색으로 표시된 부분이 공간 트랜스포머로 파라메터(예. 텍스트, 숫자 등) 컨디셔닝한 부분이다. 

class UNet_Tranformer(nn.Module): # 시간 의존된 스코어 기반 U-NET 모델
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,
                 text_dim=256, nClass=10): # marginal_prob_std: 시간 t에 대한 표준편차 반환 함수, channels: 각 해상도의 특징 맵의 채널 수, embed_dim: 가우시안 랜덤 특징 임베딩의 차원, text_dim: 텍스트/숫자의 임베딩 차원, nClass: 모델링할 클래스 수
        super().__init__()

        # 시간에 대한 가우시안 랜덤 특징 임베딩 계층
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        # 인코딩 레이어 (해상도 감소)
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.attn3 = SpatialTransformer(channels[2], text_dim)  # 컨텍스트 정보, 텍스트 임베딩 차원을 공간 트랜스포머에 설정

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        self.attn4 = SpatialTransformer(channels[3], text_dim)

        # 디코딩 레이어. 해상도 증가
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)

        # The swish activation function
        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std
        self.cond_embed = nn.Embedding(nClass, text_dim)

    def forward(self, x, t, y=None): # U-Net 아키텍처를 통과한 출력 텐서 반환. x: 입력 텐서, t: 시간 텐서, y: 타겟 텐서 (텍스트 토큰. 예. MNIST 번호). h: U-Net 아키텍처를 통과한 출력 텐서
        embed = self.act(self.time_embed(t))
        y_embed = self.cond_embed(y).unsqueeze(1)

        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h3 = self.attn3(h3, y_embed)
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        h4 = self.attn4(h4, y_embed)

        # Decoding path
        h = self.tconv4(h4) + self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2) + self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
        
스테이블 디퓨전 손실 함수 수정 및 최종 학습
이제 손실 함수도 이에 따라 수정한다. 
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5): # model: 시간 의존된 스코어 기반 모델, x: 입력 데이터 미니배치, y: 조건 정보(타겟 텐서. 예. 입력 텍스트, 숫자), marginal_prob_std: 표준편차 반환 함수, eps: 수치 안정성을 위한 허용값
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps # 미니배치 크기만큼 랜덤 시간 샘플링
    z = torch.randn_like(x) # 미니배치 크기만큼 정규 분포 랜덤 노이즈 생성
    std = marginal_prob_std(random_t) # 랜덤 시간에 대한 표준편차 계산
    perturbed_x = x + z * std[:, None, None, None] # 노이즈로 입력 데이터 왜곡

    score = model(perturbed_x, random_t, y=y) # 모델을 사용해 왜곡된 데이터와 시간에 대한 스코어 획득. 트랜스포머 어텐션에 Q로 입력되는 Y벡터 추가됨.
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3))) # score 함수와 잡음에 기반한 손실 계산
    return loss

그리고, 앞서 동일한 방식으로, 다음과 같이 학습한다. 앞의 손실 함수와 다른점은 y(MNIST 필기체 숫자. 적색표시)가 파라메터화된 조건으로 트랜스포머 어텐션 모델에 입력되었다는 것이다.
score_model = torch.nn.DataParallel(UNet_Tranformer(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 100   
batch_size = 1024 
lr = 10e-4        

transform = transforms.Compose([
    transforms.ToTensor()  # Convert image to PyTorch tensor
])
dataset = torchvision.datasets.MNIST('.', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))

tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0

    for x, y in tqdm(data_loader):
        x = x.to(device)

        loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]

    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]

    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))

    torch.save(score_model.state_dict(), 'ckpt_transformer.pth')

스테이블 디퓨전 학습 과정

학습된 스테이블 디퓨전 모델 기반 생성AI 테스트
이제 학습된 모델을 이용해, 각 입력조건(예. MNIST 필기체 숫자)에 따라 적절한 이미지를 생성하는 지를 확인한다. 
ckpt = torch.load('ckpt_transformer.pth', map_location=device)
score_model.load_state_dict(ckpt)

digit = 9 # 생성AI 입력 조건. 0~9까지 숫자 생성 가능

sample_batch_size = 64 
num_steps = 250
sampler = Euler_Maruyama_sampler
# score_model.eval()

samples = sampler(score_model,
        marginal_prob_std_fn,
        diffusion_coeff_fn,
        sample_batch_size,
        num_steps=num_steps,
        device=device,
        y=digit*torch.ones(sample_batch_size, dtype=torch.long))

# 생성 결과 확인
samples = samples.clamp(0.0, 1.0)
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()

결과와 같이, 제대로 입력 조건에 따라 이미지가 생성되는 것을 확인할 수 있다. 이로써, 스테이블 디퓨전 모델이 어떻게 멀티모달 조건에서 생성할 데이터를 학습하는 지를 확인할 수 있다.
멀티모달 입력에 대한 생성 결과('9' 입력 > 생성 이미지)

마무리
이 글에서는 멀티모달리티를 구현하는 생성AI의 아키텍처인 스테이블 디퓨전의 구현 방법을 확인하고, 실행 과정을 확인함으로써, 생성AI의 동작원리를 좀 더 깊게 살펴보았다. 

독일 뮌헨 대학에서 학습한 규모의 이미지 데이터량을 학습하려면, 개인이 하기에는 비싼 비용과 시간이 필요하므로, 이 글에서는 MNIST와 같은 소형 데이터셋을 대상으로 학습하여, 학습되는 파라메터 수를 GPU 2GB 내에서 계산될 수 있도록 하였다. 

스테이블 디퓨전은 기존에 개발되었던 오토임베딩, 잠재공간표현, U-Net, ResNet, 파라메터 컨디셔닝, 멀티 모달, CLIP, 디퓨전, 트랜스포머 어텐션 모델이 모두 적용된 멀티모달 Text To Image 모델이다. 

현재 시점에서는 멀티모달에 대한 핵심 컴포넌트가 개발되어 성능이 이미 검증되었기 때문에, 앞으로도 OpenAI SORA와 같은 Text-To-Video와 같은 Multi-Modal Gen AI 모델은 더욱 다양하게 개발될 것이라 예상된다. 
OpenAI Text To Video Gen AI - SORA

레퍼런스

댓글 없음:

댓글 쓰기