PyTorch와 Transfer Learning으로 꽃 이미지 분류기 만들기
PyTorch와 Transfer Learning으로 꽃 이미지 분류기 만들기
오늘은 PyTorch와 Transfer Learning 기법을 활용하여 꽃 이미지를 분류하는 모델을 만들어보겠습니다. 다양한 꽃 사진 데이터셋을 대상으로 하며, 사전 훈련된 모델의 지식을 활용하여 빠르고 정확한 모델을 구축할 수 있습니다. 전체 과정을 단계별로 따라가시면서 실습해보시기 바랍니다.
사전 준비
PyTorch와 관련 패키지를 설치합니다.
pip install torch torchvision
그리고 실습에 사용할 데이터셋을 다운로드 받습니다. 이번 예제에서는 아래 링크의 꽃 이미지 데이터셋을 활용하겠습니다.
https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
데이터셋을 다운로드하여 적절한 폴더에 압축을 해제해주세요.
데이터 전처리
import torch
import torchvision
from torchvision import transforms, datasets
# 데이터 전처리
data_dir = './flowers'
train_tfms = transforms.Compose([transforms.Resize(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
valid_tfms = transforms.Compose([transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
train_data = datasets.ImageFolder(data_dir+'/train', transform=train_tfms)
valid_data = datasets.ImageFolder(data_dir+'/valid', transform=valid_tfms)
# 데이터로더
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
위 코드는 꽃 이미지 데이터셋을 PyTorch에서 사용할 수 있도록 전처리합니다. 데이터 증강을 위해 Resize, RandomHorizontalFlip, Normalization 등의 변환을 적용하였습니다. 그리고 batch_size 32로 DataLoader를 생성하였습니다.
사전 훈련된 모델 로드
PyTorch에서는 torchvision.models에 다양한 사전 훈련 모델이 포함되어 있습니다. 이번에는 ResNet-18 모델을 활용해보겠습니다.
import torchvision.models as models
model = models.resnet18(pretrained=True)
# 마지막 완전연결층(Fully Connected Layer)을 새로 정의
n_inputs = model.fc.in_features
n_outputs = len(train_data.classes)
model.fc = torch.nn.Linear(n_inputs, n_outputs)
pretrained=True로 지정하여 ImageNet 데이터셋에서 미리 학습된 ResNet-18 모델 가중치를 로드합니다. 그리고 분류 작업에 맞게 마지막 완전연결층을 새로 정의합니다. 이렇게 하면 기존 모델 지식을 활용하면서도 새로운 꽃 데이터셋에 맞게 마지막 층을 재학습 시킬 수 있습니다.
모델 미세조정(Fine-tuning)
import torch.optim as optim
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
num_epochs = 30
for epoch in range(num_epochs):
# 학습 모드로 설정
model.train()
running_loss = 0.0
for inputs, labels in train_loader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
# 검증 모드로 설정
model.eval()
valid_loss = 0.0
accuracy = 0
with torch.no_grad():
for inputs, labels in valid_loader:
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
valid_loss += loss.item()
prob = torch.exp(outputs)
top_p, top_class = prob.topk(1, dim=1)
equals = top_class == labels.view(*top_class.shape)
accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
train_loss = running_loss / len(train_loader)
valid_loss = valid_loss / len(valid_loader)
print(f'Epoch: {epoch+1:02}')
print(f'\tTrain Loss: {train_loss:.3f}')
print(f'\tValid Loss: {valid_loss:.3f}')
print(f'\tAccuracy: {accuracy/len(valid_loader):.3f}')
이 코드는 꽃 데이터셋에서 ResNet-18 모델을 미세조정(Fine-tuning)하는 과정입니다. GPU가 있다면 GPU를 사용하도록 설정하였습니다.
CrossEntropyLoss를 손실 함수로, SGD를 옵티마이저로 설정했습니다. 30 에포크 동안 학습하면서 매 에포크마다 학습/검증 손실과 정확도를 출력합니다.
모델 평가 및 배포
30 epoch 학습이 완료되면 최종 모델의 성능을 평가합니다.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
test_accuracy = 0.0
for inputs, labels in test_loader:
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
prob = torch.exp(outputs)
top_p, top_class = prob.topk(1, dim=1)
equals = top_class == labels.view(*top_class.shape)
test_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
print(f'Test Accuracy: {test_accuracy/len(test_loader):.3f}')
위 코드로 테스트 데이터셋에 대한 모델의 최종 정확도를 계산하였습니다.
마지막으로 학습된 모델을 배포하기 위해 ONNX 포맷으로 변환하고 웹/모바일 앱에 서빙할 수 있습니다.
dummy_input = torch.randn(1, 3, 224, 224)
torchscript_model = torch.jit.trace(model, dummy_input)
# ONNX 모델로 변환
import torch.onnx
torch.onnx.export(torchscript_model, dummy_input, "flower_classifier.onnx")
ONNX 변환 후에는 웹 프레임워크(Flask, FastAPI 등)나 모바일 앱에서 ONNX 런타임을 사용하여 모델 추론을 수행할 수 있습니다.
지금까지 PyTorch와 Transfer Learning 기법을 활용해 꽃 이미지 분류 모델을 만들어보는 전체 과정을 살펴보았습니다. 실제 프로젝트에서 활용해보시면서 이해를 높여보시기 바랍니다.