U-Net 구현
이미지를 다룰 때, 이미지 특징을 캡쳐하기 위해, U-Net 네트워크를 사용한다. U-Net 은 다른 공간 스케일에 대한 이미지 특징을 학습하는 것에 초점을 둔다.
U-Net는 시간에 따라 어떻게 데이터가 변경되는 지를 학습해야 한다. 이 모델은 이 패턴을 학습한다. U-Net의 인코딩 경로는 이미지 해상도 다운샘플링, 특징 캡쳐를 위한 컨볼루션 레이어인 h1, h2, h3, h4로 구성된다. 디코딩 경로는 transpose 컨볼루션 레이어이며, 텐서 h가 h4에서 h1 레이어를 통과하여, 업샘플링된다.
class UNet(nn.Module): # U-Net 모델 정의. nn.Module 클래스 상속
def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256): # marginal_prob_std: 시간 t에 대한 표준편차 반환 함수, channels: 각 해상도의 특징 맵의 채널 수, embed_dim: 가우시안 랜덤 특징 임베딩의 차원
super().__init__()
# 시간에 대한 가우시안 랜덤 특징 임베딩 계층
self.time_embed = nn.Sequential(
GaussianFourierProjection(embed_dim=embed_dim),
nn.Linear(embed_dim, embed_dim)
)
# 인코딩 레이어 (해상도 감소)
self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
self.dense1 = Dense(embed_dim, channels[0])
self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
self.dense2 = Dense(embed_dim, channels[1])
self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
# 추가 인코딩 레이어 (원본 코드에서 복사)
self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
self.dense3 = Dense(embed_dim, channels[2])
self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
self.dense4 = Dense(embed_dim, channels[3])
self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
# 해상도 증가를 위한 디코딩 레이어
self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
self.dense5 = Dense(embed_dim, channels[2])
self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
self.dense6 = Dense(embed_dim, channels[1])
self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
self.dense7 = Dense(embed_dim, channels[0])
self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)
# 스위치 활성화 함수
self.act = lambda x: x * torch.sigmoid(x)
self.marginal_prob_std = marginal_prob_std
def forward(self, x, t, y=None): # U-Net 아키텍처를 통과한 출력 텐서 반환. x: 입력 텐서, t: 시간 텐서, y: 타겟 텐서 (이 전방 통과에서 사용되지 않음). h: U-Net 아키텍처를 통과한 출력 텐서
# t에 대한 가우시안 랜덤 특징 임베딩 획득
embed = self.act(self.time_embed(t))
# 인코딩 경로
h1 = self.conv1(x) + self.dense1(embed)
h1 = self.act(self.gnorm1(h1))
h2 = self.conv2(h1) + self.dense2(embed)
h2 = self.act(self.gnorm2(h2))
# 추가 인코딩 경로 레이어 (원본 코드에서 복사)
h3 = self.conv3(h2) + self.dense3(embed)
h3 = self.act(self.gnorm3(h3))
h4 = self.conv4(h3) + self.dense4(embed)
h4 = self.act(self.gnorm4(h4))
# 디코딩 경로
h = self.tconv4(h4)
h += self.dense5(embed)
h = self.act(self.tgnorm4(h))
h = self.tconv3(h + h3)
h += self.dense6(embed)
h = self.act(self.tgnorm3(h))
h = self.tconv2(h + h2)
h += self.dense7(embed)
h = self.act(self.tgnorm2(h))
h = self.tconv1(h + h1)
# 정규화 출력
h = h / self.marginal_prob_std(t)[:, None, None, None]
return h
포워드 디퓨전 프로세스
전방향 확산 모델을 효과적으로 계산하기 위해, 다음 방정식을 고려한다.
이 방정식은 변수 x에 대한 변화가 시간 t에 대한 노이즈가 dw에 에 비례하는 방식으로 동작한다. 노이즈 수준은 파라메터 σ에 결정되며, 지수적으로 증가한다.
이 프로세스는 처음 x(0) 초기값이 주어지고, x(t)에 대한 분석적인 솔류션을 찾을 수 있다.
이 모델에서, σ(t)는 표준 편차(marginal standard deviation)로 사용되며, x(t)는 분산의 변동으로 사용된다. σ(t)은 다음과 같이 계산된다.
이 모델은 시간에 따라 어떻게 잡음 수준 σ가 정해지는 지 이해를 제공한다. 다음은 이를 코딩한 것이다.
device = "cuda"
def marginal_prob_std(t, sigma): # 시간 t에 대한 표준편차 반환 함수
t = torch.tensor(t, device=device) # 시간 텐서
return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma)) # 시간 t에 대한 표준편차 반환
def diffusion_coeff(t, sigma): # 확산 계수 함수. t: 시간 텐서, sigma: SDE의 시그마
return torch.tensor(sigma**t, device=device) # 확산 계수 반환
sigma = 25.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
손실 함수 구현
U-Net 구현에 앞서, 언급되었던 score 함수를 학습할 수 있도록, loss 함수를 구현한다.
이 함수는 임의의 시간단위를 샘플링하고, 여기서 잡음 수준을 얻은 후, 이 잡음을 데이터와 더한다. 그리고, 실제 데이터와 에측 데이터간의 오차를 계산하여, 에러를 줄이를 방향으로 학습하도록 한다.
def loss_fn(model, x, marginal_prob_std, eps=1e-5): # 손실함수. score-based generative models 훈련용. model: 시간 의존 스코어 모델, x: 훈련 데이터 미니배치, marginal_prob_std: 표준편차 반환 함수, eps: 수치 안정성을 위한 허용값
random_t = torch.rand(x.shape[0], device=x.device) * (1. - 2 * eps) + eps # 미니배치 크기만큼 랜덤 시간 샘플링
std = marginal_prob_std(random_t) # 랜덤 시간에 대한 표준편차 계산
z = torch.randn_like(x) # 미니배치 크기만큼 정규 분포 랜덤 노이즈 생성
perturbed_x = x + z * std[:, None, None, None] # 노이즈로 입력 데이터 왜곡
score = model(perturbed_x, random_t) # 왜곡 데이터, 시간 입력, 계산된 스코어 획득
loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3))) # score 함수와 잡음에 기반한 손실값 계산
return loss
샘플러 코딩
Stable Diffusion는 임의 시점에서 이미지를 생성한다. 노이즈 예측기(noise predictor)는 얼마나 이미지에 노이즈를 줄 것인지를 예측한다. 이 예측된 노이즈는 이미지로 부터 제거된다. 전체 사이클은 몇 번 반복되어 깨끗한 이미지를 생성하게 된다.
이 clearning-up 프로세스를 '샘플링'으로 알려져 있고, 매 학습 단계마다 새로운 이미지를 생성되도록 한다. 이를 샘플러, 샘플링 방법이라 한다.
스테이블 디퓨전은 이미지 샘플링을 생성하기 위한 다양한 방법이 있다. 여기서는 Euler-Maruyama 방법론을 사용한다.
# number of steps
num_steps = 500
def Euler_Maruyama_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
x_shape=(1, 28, 28),
num_steps=num_steps,
device='cuda',
eps=1e-3, y=None): # Euler-Maruyama sampler 함수. score-based 모델을 사용해 샘플 생성. score_model: 시간 의존 스코어 모델, marginal_prob_std: 표준편차 반환 함수, diffusion_coeff: 확산 계수 함수, batch_size: 한번 호출시 생성할 샘플러 수, x_shape: 샘플 형태, num_steps: 샘플링 단계 수, device: 'cuda' 또는 'cpu', eps: 수치 안정성을 위한 허용값, y: 타겟 텐서 (이 함수에서 사용되지 않음)
t = torch.ones(batch_size, device=device)
init_x = torch.randn(batch_size, *x_shape, device=device) * marginal_prob_std(t)[:, None, None, None] # Initial sample
time_steps = torch.linspace(1., eps, num_steps, device=device)
step_size = time_steps[0] - time_steps[1] # Step size 시리즈
x = init_x # 시간 t에 대한 초기 샘플
with torch.no_grad(): # Euler-Maruyama 샘플링
for time_step in tqdm(time_steps):
batch_time_step = torch.ones(batch_size, device=device) * time_step
g = diffusion_coeff(batch_time_step)
mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
return mean_x
U-Net 학습 및 결과 시각화
앞서 구현된 U-Net 모델을 학습한다. 학습은 50 에폭, 2024 미니배치크기로 진행하며, 데이터셋은 MNIST를 사용한다. 학습된 모델은 chpt.pth 로 저장된다.
batch_size = 2048
transform = transforms.Compose([
transforms.ToTensor() # Convert image to PyTorch tensor
# transforms.Normalize((0.5,), (0.5,)) # Normalize images with mean and std
])
dataset = torchvision.datasets.MNIST('.', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
n_epochs = 50
lr = 5e-4
optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for x, y in tqdm(data_loader):
x = x.to(device)
loss = loss_fn(score_model, x, marginal_prob_std_fn)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() * x.shape[0]
num_items += x.shape[0]
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
torch.save(score_model.state_dict(), 'ckpt.pth')
학습된 결과를 다음과 같이 시각화해본다. 샘플러는 앞서 정의한 Euler_Maruyama_sampler이며, 이 함수는 앞서 U-Net으로 학습된 score_function을 호출해 그 결과를 리턴한다(샘플링 결과). 결과를 보면 알겠지만, 노이즈 입력에 대한 필기체가 출력된 것을 확인할 수 있다.
sample_batch_size = 64
num_steps = 500
sampler = Euler_Maruyama_sampler # Euler-Maruyama sampler 사용
samples = sampler(score_model, marginal_prob_std_fn,
diffusion_coeff_fn, sample_batch_size,
num_steps=num_steps, device=device, y=None)
samples = samples.clamp(0.0, 1.0)
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.figure(figsize=(6, 6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
이제 텍스트(여기서는 숫자)를 입력하면, 이미지가 생성되도록 어텐션 레이어를 추가해 본다.
멀티모달 처리를 위한 어텐션 레이어와 트랜스포머 모듈 개발
트랜스포머 아키텍처의 어텐션 레이어는 입력에 대한 cross attention, spatial transformer (공간 트랜스포머)를 구현해야 한다.
어텐션 모델은 QKV(Query, Key, Value) 벡터로 학습된 컨텍스트 KV에 대한 질의 Q에 대한 스코어를 계산한다. 이런 특징은 멀티모달 생성AI를 가능하게 하는 핵심이 된다(예. KV는 이미지 학습, Q는 음성 입력). 이 벡터 QKV는 다음과 같이 차원 공간에서 기저벡터 e로 표현될 수 있다.
핵심은 Q와 KV의 유사도 계산한 상호 벡터간 거리가 가까워지는 방향으로 학습 모델을 설계하는 것에 있다(상세한 설명은 글 앞의 트랜스포머 설명 링크를 참고). 여기서는 Q가 MNIST의 각 숫자 번호가 될 것이다. 참고로, 셀프 어텐션은 입력 토큰 내 관계를 학습하며, 교차 어텐션의 경우 입력 토큰과 컨텍스트 특징 간의 관계를 학습한다. 이 계산 결과로 컨텍스트 벡터를 리턴한다.
class CrossAttention(nn.Module):
def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1): # 임베딩 차원, 은닉 차원, 컨텍스트 차원(self attention이면 None), 어텐션 헤드 수
super(CrossAttention, self).__init__()
self.hidden_dim = hidden_dim
self.context_dim = context_dim
self.embed_dim = embed_dim
self.query = nn.Linear(hidden_dim, embed_dim, bias=False) # Query에 대한 학습을 위한 선형 레이어
if context_dim is None:
self.self_attn = True
self.key = nn.Linear(hidden_dim, embed_dim, bias=False)
self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)
else:
self.self_attn = False
self.key = nn.Linear(context_dim, embed_dim, bias=False)
self.value = nn.Linear(context_dim, hidden_dim, bias=False)
def forward(self, tokens, context=None): # 토큰들[배치, 시퀀스 크기, 은닉 차원], 컨텍스트 정보[배치, 컨텍스트 시퀀스 크기, 컨텍스트 차원]. self_attn이 True면 컨텍스트는 무시됨
if self.self_attn: # Self-attention case
Q = self.query(tokens)
K = self.key(tokens)
V = self.value(tokens)
else: # Cross-attention case
Q = self.query(tokens)
K = self.key(context)
V = self.value(context)
# Compute score matrices, attention matrices, and context vectors
scoremats = torch.einsum("BTH,BSH->BTS", Q, K) # Q, K간 내적 계산. 스코어 행렬 획득
attnmats = F.softmax(scoremats / math.sqrt(self.embed_dim), dim=-1) # 스코어 행렬의 softmax 계산
ctx_vecs = torch.einsum("BTS,BSH->BTH", attnmats, V) # 어텐션 행렬 적용된 V벡터 계산
return ctx_vecs
앞의 어텐션 레이어를 포함한 트랜스포머 모듈을 개발한다.
이 트랜스포머 모듈은 교차 어텐션, 피드포워드 신경망을 통합한다. 차원 형태는 [batch, sequence_len, hidden_dim] 입력 텐서, [batch, context_seq_len, context_dim]인 컨텍스트 텐서를 사용한다. 셀프-어텐션 및 크로스 어텐션 모듈 처리 다음에 레이어 정규화, 잔차 연결이 계산된다. 비선형 변환을 위한 GELU, MLP 레이어가 포함된다.
class TransformerBlock(nn.Module):
def __init__(self, hidden_dim, context_dim): # 은닉 차원, 컨텍스트 차원
super(TransformerBlock, self).__init__()
self.attn_self = CrossAttention(hidden_dim, hidden_dim)
self.attn_cross = CrossAttention(hidden_dim, hidden_dim, context_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.norm3 = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, 3 * hidden_dim),
nn.GELU(),
nn.Linear(3 * hidden_dim, hidden_dim)
) # Feed forward neural network. 2개의 레이어로 구성. 첫번째 레이어는 3 * hidden_dim개의 은닉 유닛을 가지고 nn.GELU 비선형성 함수를 사용. 두번째 레이어는 hidden_dim개의 은닉 유닛을 가짐
def forward(self, x, context=None): # x: 입력 텐서[배치, 시퀀스 크기, 은닉 차원], context: 컨텍스트 텐서[배치, 컨텍스트 시퀀스 크기, 컨텍스트 차원]
x = self.attn_self(self.norm1(x)) + x # self-attention 적용 후 layer normalization과 잔차 연결 적용
x = self.attn_cross(self.norm2(x), context=context) + x # cross-attention 적용 후 layer normalization과 잔차 연결 적용
x = self.ffn(self.norm3(x)) + x # feed forward neural network 적용 후 layer normalization과 잔차 연결 적용
return x
이제, 공간 트랜스포머 모듈로 앞의 모듈을 통합한다.
class SpatialTransformer(nn.Module):
def __init__(self, hidden_dim, context_dim):
super(SpatialTransformer, self).__init__()
self.transformer = TransformerBlock(hidden_dim, context_dim)
def forward(self, x, context=None): # x: 입력 텐서[배치, 채널, 높이, 너비], context: 컨텍스트 텐서[배치, 컨텍스트 시퀀스 크기, 컨텍스트 차원]
b, c, h, w = x.shape
x_in = x
x = rearrange(x, "b c h w -> b (h w) c") # 입력 텐서의 차원을 재배열
x = self.transformer(x, context) # 트랜스포머 블록 적용
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) # 텐서의 차원을 원래대로 복원
return x + x_in # 공간 변환기의 출력과 입력의 잔차 연결
U-Net과 공간 트랜스포머 통합
앞서 설명된 U-Net과 공간 트랜스포머를 통합해, 특정 조건에 따라(예. 입력 텍스트, 숫자 등), 이미지가 생성되도록 U-Net 학습 모델을 일부 수정한다. 적색으로 표시된 부분이 공간 트랜스포머로 파라메터(예. 텍스트, 숫자 등) 컨디셔닝한 부분이다.
class UNet_Tranformer(nn.Module): # 시간 의존된 스코어 기반 U-NET 모델
def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,
text_dim=256, nClass=10): # marginal_prob_std: 시간 t에 대한 표준편차 반환 함수, channels: 각 해상도의 특징 맵의 채널 수, embed_dim: 가우시안 랜덤 특징 임베딩의 차원, text_dim: 텍스트/숫자의 임베딩 차원, nClass: 모델링할 클래스 수
super().__init__()
# 시간에 대한 가우시안 랜덤 특징 임베딩 계층
self.time_embed = nn.Sequential(
GaussianFourierProjection(embed_dim=embed_dim),
nn.Linear(embed_dim, embed_dim)
)
# 인코딩 레이어 (해상도 감소)
self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
self.dense1 = Dense(embed_dim, channels[0])
self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
self.dense2 = Dense(embed_dim, channels[1])
self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
self.dense3 = Dense(embed_dim, channels[2])
self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
self.attn3 = SpatialTransformer(channels[2], text_dim) # 컨텍스트 정보, 텍스트 임베딩 차원을 공간 트랜스포머에 설정
self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
self.dense4 = Dense(embed_dim, channels[3])
self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
self.attn4 = SpatialTransformer(channels[3], text_dim)
# 디코딩 레이어. 해상도 증가
self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
self.dense5 = Dense(embed_dim, channels[2])
self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
self.dense6 = Dense(embed_dim, channels[1])
self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
self.dense7 = Dense(embed_dim, channels[0])
self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)
# The swish activation function
self.act = nn.SiLU()
self.marginal_prob_std = marginal_prob_std
self.cond_embed = nn.Embedding(nClass, text_dim)
def forward(self, x, t, y=None): # U-Net 아키텍처를 통과한 출력 텐서 반환. x: 입력 텐서, t: 시간 텐서, y: 타겟 텐서 (텍스트 토큰. 예. MNIST 번호). h: U-Net 아키텍처를 통과한 출력 텐서
embed = self.act(self.time_embed(t))
y_embed = self.cond_embed(y).unsqueeze(1)
# Encoding path
h1 = self.conv1(x) + self.dense1(embed)
h1 = self.act(self.gnorm1(h1))
h2 = self.conv2(h1) + self.dense2(embed)
h2 = self.act(self.gnorm2(h2))
h3 = self.conv3(h2) + self.dense3(embed)
h3 = self.act(self.gnorm3(h3))
h3 = self.attn3(h3, y_embed)
h4 = self.conv4(h3) + self.dense4(embed)
h4 = self.act(self.gnorm4(h4))
h4 = self.attn4(h4, y_embed)
# Decoding path
h = self.tconv4(h4) + self.dense5(embed)
h = self.act(self.tgnorm4(h))
h = self.tconv3(h + h3) + self.dense6(embed)
h = self.act(self.tgnorm3(h))
h = self.tconv2(h + h2) + self.dense7(embed)
h = self.act(self.tgnorm2(h))
h = self.tconv1(h + h1)
# Normalize output
h = h / self.marginal_prob_std(t)[:, None, None, None]
return h
스테이블 디퓨전 손실 함수 수정 및 최종 학습
이제 손실 함수도 이에 따라 수정한다.
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5): # model: 시간 의존된 스코어 기반 모델, x: 입력 데이터 미니배치, y: 조건 정보(타겟 텐서. 예. 입력 텍스트, 숫자), marginal_prob_std: 표준편차 반환 함수, eps: 수치 안정성을 위한 허용값
random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps # 미니배치 크기만큼 랜덤 시간 샘플링
z = torch.randn_like(x) # 미니배치 크기만큼 정규 분포 랜덤 노이즈 생성
std = marginal_prob_std(random_t) # 랜덤 시간에 대한 표준편차 계산
perturbed_x = x + z * std[:, None, None, None] # 노이즈로 입력 데이터 왜곡
score = model(perturbed_x, random_t, y=y) # 모델을 사용해 왜곡된 데이터와 시간에 대한 스코어 획득. 트랜스포머 어텐션에 Q로 입력되는 Y벡터 추가됨.
loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3))) # score 함수와 잡음에 기반한 손실 계산
return loss
그리고, 앞서 동일한 방식으로, 다음과 같이 학습한다. 앞의 손실 함수와 다른점은 y(MNIST 필기체 숫자. 적색표시)가 파라메터화된 조건으로 트랜스포머 어텐션 모델에 입력되었다는 것이다.
score_model = torch.nn.DataParallel(UNet_Tranformer(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
n_epochs = 100
batch_size = 1024
lr = 10e-4
transform = transforms.Compose([
transforms.ToTensor() # Convert image to PyTorch tensor
])
dataset = torchvision.datasets.MNIST('.', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for x, y in tqdm(data_loader):
x = x.to(device)
loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() * x.shape[0]
num_items += x.shape[0]
scheduler.step()
lr_current = scheduler.get_last_lr()[0]
print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
torch.save(score_model.state_dict(), 'ckpt_transformer.pth')
스테이블 디퓨전 학습 과정
학습된 스테이블 디퓨전 모델 기반 생성AI 테스트
이제 학습된 모델을 이용해, 각 입력조건(예. MNIST 필기체 숫자)에 따라 적절한 이미지를 생성하는 지를 확인한다.
ckpt = torch.load('ckpt_transformer.pth', map_location=device)
score_model.load_state_dict(ckpt)
digit = 9 # 생성AI 입력 조건. 0~9까지 숫자 생성 가능
sample_batch_size = 64
num_steps = 250
sampler = Euler_Maruyama_sampler
# score_model.eval()
samples = sampler(score_model,
marginal_prob_std_fn,
diffusion_coeff_fn,
sample_batch_size,
num_steps=num_steps,
device=device,
y=digit*torch.ones(sample_batch_size, dtype=torch.long))
# 생성 결과 확인
samples = samples.clamp(0.0, 1.0)
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
결과와 같이, 제대로 입력 조건에 따라 이미지가 생성되는 것을 확인할 수 있다. 이로써, 스테이블 디퓨전 모델이 어떻게 멀티모달 조건에서 생성할 데이터를 학습하는 지를 확인할 수 있다.
멀티모달 입력에 대한 생성 결과('9' 입력 > 생성 이미지)
마무리