2022년 8월 20일 토요일

간단한 Mask R-CNN 기반 객체 세그먼테이션 딥러닝 학습 및 예측 프로그램 만들기

이 글은 Mask R-CNN 기반 객체 세그먼테이션 딥러닝 학습 및 예측 프로그램 개발 방법을 간략히 설명한다. 객체 감지 및 인스턴스 분할은 이미지에서 객체를 식별하고 분할하는 작업이다. 여기에는 각 객체에 대한 경계 상자, 정확한 객체를 덮는 마스크 및 객체 클래스를 찾는 작업이 포함된다. Mask R-CNN 은 이를 달성하기 위한 가장 일반적인 방법 중 하나이다. 이 글은 Mask R-CNN의 이론적 내용을 자세히 설명하지는 않는다(자세한 내용은 레퍼런스를 참고한다).
세그먼테이션된 객체들
Fast R-CNN 아키텍처 개념도

개발 환경 설정
개발을 위해, 파이썬 및 파이토치 개발 환경을 마련한 후(참고), 다음 라이브러리를 설치한다. 
pip install opencv-python

구현하기
다음과 같이, 토치비전에서 R-CNN 라이브러리를 불러오고, 초기화한다.
import random
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import numpy as np
import torch.utils.data
import cv2
import torchvision.models.segmentation
import torch
import os
batchSize=2
imageSize=[600,600]
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

다음과 같이 이미지 경로를 변수에 설정한다. 
trainDir="LabPicsChemistry/Train"

imgs=[]
for pth in os.listdir(trainDir):
    imgs.append(trainDir+"/"+pth +"//")

해당 이미지는 다음과 같이 특정 경로에 마스크 이미지가 저장되어 있다고 가정한다(참고 - 연구실 학습용 데이터셋 다운로드 링크). 앞의 trainDir 변수 경로에 맞게, 학습용 데이터셋 폴더 경로가 설정되어 있도록 한다.

저장된 데이터들을 로딩한 후, 학습에 맞게 변환하는 코드를 구현한다.
def loadData():
  batch_Imgs=[]
  batch_Data=[]
  for i in range(batchSize):
        idx=random.randint(0,len(imgs)-1)
        img = cv2.imread(os.path.join(imgs[idx], "Image.jpg"))
        img = cv2.resize(img, imageSize, cv2.INTER_LINEAR)    # 이미지 크기 변환
        maskDir=os.path.join(imgs[idx], "Vessels")
        masks=[]
        for mskName in os.listdir(maskDir):
            vesMask = cv2.imread(maskDir+'/'+mskName, 0)
            vesMask = (vesMask > 0).astype(np.uint8) 
            vesMask=cv2.resize(vesMask,imageSize,cv2.INTER_NEAREST)
            masks.append(vesMask)
        num_objs = len(masks)
        if num_objs==0: return loadData()
        boxes = torch.zeros([num_objs,4], dtype=torch.float32)
        for i in range(num_objs):
            x,y,w,h = cv2.boundingRect(masks[i])
            boxes[i] = torch.tensor([x, y, x+w, y+h])
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        img = torch.as_tensor(img, dtype=torch.float32)    # 텐서 데이터로 변환
        data = {}
        data["boxes"] =  boxes
        data["labels"] =  torch.ones((num_objs,), dtype=torch.int64)   
        data["masks"] = masks
        batch_Imgs.append(img)
        batch_Data.append(data)  
  
  batch_Imgs=torch.stack([torch.as_tensor(d) for d in batch_Imgs],0)
  batch_Imgs = batch_Imgs.swapaxes(1, 3).swapaxes(2, 3)
  return batch_Imags, batch_Data

다음과 같이, 딥러닝 학습 코드를 구현한다.
model=torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)  # 전이학습

for i in range(10001):
   images, targets = loadData()
   images = list(image.to(device) for image in images)
   targets=[{k: v.to(device) for k,v in t.items()} for t in targets]
   
   optimizer.zero_grad()
   loss_dict = model(images, targets)   # 모델 손실 계산
   losses = sum(loss for loss in loss_dict.values())
   
   losses.backward()    # 역전파   
   optimizer.step()       # 신경망 가중치 업데이트
   
   print(i,'loss:', losses.item())
   if i%200==0:
           torch.save(model.state_dict(), str(i)+".torch")
           print("Save model to:",str(i)+".torch")

학습 후, 이미지 데이터를 모델에 입력해, 예측해 본다. 
images = cv2.imread(imgPath)  # 특정 경로 이미지 로딩
images = cv2.resize(images, imageSize, cv2.INTER_LINEAR)
images = torch.as_tensor(images, dtype=torch.float32).unsqueeze(0)
images=images.swapaxes(1, 3).swapaxes(2, 3)
images = list(image.to(device) for image in images)

with torch.no_grad():
    pred = model(images)  # 모델 예측. 세그먼테이션

예측 스코어가 0.8 이상 세그먼트만 출력해 본다.
im= images[0].swapaxes(0, 2).swapaxes(0, 1).detach().cpu().numpy().astype(np.uint8)
im2 = im.copy()
for i in range(len(pred[0]['masks'])):
    msk=pred[0]['masks'][i,0].detach().cpu().numpy()
    scr=pred[0]['scores'][i].detach().cpu().numpy()
    if scr>0.8 :
        im2[:,:,0][msk>0.5] = random.randint(0,255)
        im2[:, :, 1][msk > 0.5] = random.randint(0,255)
        im2[:, :, 2][msk > 0.5] = random.randint(0, 255)
cv2.imshow(str(scr), np.hstack([im,im2]))
cv2.waitKey()

다음은 예측 결과이다.

마무리
최근 이미지, 텍스트, 사운드 기반 분류, 예측, 패턴 인식, 재구성과 같은 널리 알려진 문제에 대한 솔류션은 PyTorch, Keras 등에 내장되어 쉽게 학습된 모델을 다운로드하고, 사용할 수 있게 되었다. 이러한 기능을 잘 활용한다면, 인공지능을 이용한 서비스를 좀 더 쉽게 개발할 수 있을 것이다.

레퍼런스