본문 바로가기
프로젝트 기록부

개인 프로젝트 04. - 강아지 VS 고양이 사진 분류 AI 웹 서비스 개발

by Amins 2023. 5. 18.

Preview

Abstract

GOAL : 강아지와 고양이를 분류하는 인공지능 웹 서비스 만들기
Benefit : 인공지능 서비스 배포의 전반적 흐름 파악 및 웹 프레임워크 숙달, 보유한 AI기술과 프론트엔드 기술의 융합 도모, 일반 유저에게 이미지에 대한 AI Inference 수준을 체감시켜 AI에 대한 인식을 재고
Task : Google Colab과 Tensorflow를 이용해 이진분류 AI를 학습하고, Flask 웹 프레임워크를 이용해 학습된 AI가 유저의 사진을 분석해 답해주는 웹 서비스를 배포 
Solution Method
- Tensorflow CatsvsDogs Dataset을 이용한 CNN 기반 AI 이진분류기 구현 및 학습, 학습 성능 향상을 위한 전이 학습 방식 사용, h5 포맷으로 모델 구조 및 학습된 파라미터 값 저장
- 개발 및 테스트를 위해 Anaconda를 이용한 colab환경과 동일한 로컬 가상환경 구축
- Flask 웹 프레임워크 내에 학습된 이진분류 모델을 불러와 프론트엔드와 연결
- 유저의 이미지 입력을 추론 모델에 넣기 위한 이미지 전처리 함수 제작
- HTML, CSS, JavaScript를 이용해 웹 서비스의 골격 및 스타일, 입력 이미지와 추론 결과를 띄워주는 함수 제작 
Tools
- Google Colab Pro(Jupyter Notebook)
- VScode(Flask, Python3, HTML, CSS, JavaScript, Jinja)
- Github
- Anaconda

Background

늘 인공지능을 수학하며, 내가 만든 모델을 단순히 Metric과 같은 수치로 평가하고 끝내는 것이 아니라, 실제 유저들이 이용할 수 있는 서비스로 연결해 보고 싶었다. 아무리 성능 좋은 알고리즘이라도, 서비스로 사용되지 않으면 의미가 없기 때문이다. 따라서 간단한 인공지능 모델을 학습시켜 웹 서비스를 구현해보고자 했다. 조사를 해 보니, 간단한 머신 러닝 모델 배포에 사용되는 웹 프레임워크로는 Flask가 가장 나았다. 가볍고, 같은 Python을 이용한 프레임워크다 보니 배우기도 용이했다. 이 프로젝트는 Flask도 숙달하고, 지금껏 배운 HTML/CSS/JS를 내가 가진 AI 기술과 접목하는데 많은 도움이 될 것이라 생각한다. 궁극적으로, 인공지능 모델 배포 워크 플로우를 체득하고, AI와 웹이 어떻게 상호작용하는지를 배워 AI 모델을 웹에 배포할 수 있는 능력을 함양하고자 한다.


Review

1. 이진분류기 학습

이진분류기 학습은 Tensorflow Dataset CatsVSDogs를 이용했다. 이미지 분류이며 문제의 난도가 크게 어렵지 않다고 판단하여 간단한 CNN 모델을 이용했다. 빠른 결과를 위해 10 epochs만 학습을 진행했고, 약 75%의 정확도를 얻었다. 해당 모델을 더 튜닝할 수도 있었지만, 기존에 학습된 모델을 가져와 파인튜닝 하는 것이 더 효율적이라고 생각해  Imagenet 데이터로 학습된 MobileNet을 불러왔다. 최상층 레이어는 이진분류 문제에 맞게 교체해 주었다. 

이후 모든 레이어를 얼리고, 최상층 레이어만 20epochs 가량 학습해 모델을 문제에 피팅시켰고, 이후 모든 레이어의 잠금을 풀고 매우 낮은 learning rate(1e-5)로 10 epochs 추가 학습해 파라미터 튜닝을 마쳤다. 이를 통해 최종 Test Accuracy 98%를 달성했다. 학습된 모델은 h5 포맷으로 저장했다. 아래는 분류기 학습에 이용된 코드이다.

https://github.com/haminse/AI_binary_classification_web/blob/main/cat_dog_Classifier.ipynb

 

GitHub - haminse/AI_binary_classification_web: AI_binary_classification_web_service_production

AI_binary_classification_web_service_production. Contribute to haminse/AI_binary_classification_web development by creating an account on GitHub.

github.com

 

2. Flask를 통한 웹 서비스 제작

가장 먼저, Anaconda를 이용해 로컬 가상환경을 구축했다. 내 노트북의 로컬 환경에서는 파이썬 3.11을 사용하고 있는데, 이 최신 버전의 파이썬이 텐서플로와 같은 라이브러리들과 호환이 잘 되지 않아 버전 충돌이 일어났기 때문이다. 모델 학습을 Colab환경에서 실시했기 때문에 Colab환경과 동일한 Python, 각종 프레임워크 및 라이브러리 버전을 맞추는 것이 중요했다.

이후 Flask를 이용해 웹 어플리케이션을 제작했다. Python기반의 프레임워크다 보니 Tensorflow 메서드들을 자유롭게 사용할 수 있었다. 아래는 이미지 전처리 함수, 모델 로드 및 할당, 라우팅(URL에 해당하는 HTML 및 동작을 할당) 등을 관장하는 Flask 코드이다.

 

from flask import Flask, render_template, request, redirect
import tensorflow as tf
# from tensorflow import keras
from tensorflow.keras.models import load_model
# from keras.preprocessing import image
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import tempfile
# import os

# loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# optimizer = tf.keras.optimizers.Adam()
# metrics = tf.keras.metrics.BinaryAccuracy()
# print('keras : ', keras.__version__)
app = Flask(__name__)
dic = {0 : 'Cat', 
       1 : 'Dog'}

model = load_model('model.h5')

def predict_label(image):
    # i = load_img(img_path, target_size = (122,122))
    # image = image.resize((224,224))
    i = img_to_array(image)/255.0
    i = i.reshape(1,224,224,3)
    p = model.predict(i)
    p = (p >= 0).astype(int)
    return dic[p[0][0]]


@app.route("/", methods = ['GET', 'POST'])
def index():
    return render_template('index.html', answer = None)

@app.route("/upload", methods = ['POST'])
def upload():
    img = request.files['image']

    # to not save the user input
    _, temp_file_path = tempfile.mkstemp() # when I dont want to save image but no display after upload also
    
    # to save the user input img
    # temp_file_path = 'static/' + img.filename

    img.save(temp_file_path)
    print(temp_file_path)

    # load the image from the file
    image = load_img(temp_file_path, target_size=(224, 224))
    #predict the label
    rst = predict_label(image)
    # pass the image path to the template
    # img_path = os.path.basename(temp_file_path)

    #delete the user image
    # os.remove(temp_file_path)

    return render_template('index.html', answer = f"It's a {rst}.")#, img_path = img_path)

@app.route("/info", methods = ['GET', 'POST'])
def info():
    return render_template('info.html')

# to debug in local
app.run(debug = True, port = 5003)

# to deploy
# if __name__ == "__main__":
#     app.run(host='0.0.0.0', port=80)

 

@app.route(url) 구문이 Flask에서의 라우팅을 다룬다. 해당 구문 내의 URL에서 아래의 함수가 실행되는 방식이다. 

Flask를 로컬에서 실행할 때는 로컬 환경에서의 port를 설정해 주어야 하는데, 나는 5000 ~  5002 포트가 다른 웹 앱에 할당되어 있어 5003으로 설정했다. 또 debug = True 파라미터를 통해 로컬에서 코드를 수정했을 때 브라우저에서 바로 바로 수정 사항을 확인할 수 있다. 배포 시에는 # to deploy 아래의 코드를 대신 활성화 한 뒤 실행해야 한다.

 

업로드 된 사진 보안 이슈

유저들이 이미지을 업로드하고, 이를 분석하는 방식의 웹 서비스라 업로드된 이미지를 저장하는 문제가 발생했다. 입력 받은 사진을 불러오기 위해 load_img() 모듈을 사용해 사진이 저장된 경로를 지정해 야했는데, 이 때문에 사진을 저장하는 것이 필수적이었다. 그러나 이렇게 모든 업로드된 사진들을 저장하자니 보안이나 서버 데이터베이스 메모리 문제가 예상되었다. 개인정보 보호법을 찾아보니 사진을 저장하려면 사전에 유저들의 동의를 받아야하는데, 이 과정에서 많은 유저들이 서비스에서 이탈될 것이라 느꼈다. 

이에 대한 해결책으로 tempfile 라이브러리를 사용했다. 가상의 임시 파일 경로를 만들어 주는 기능인데, 이를 통해 가상의 경로에 업로드된 사진을 임시 저장하고, 이미지 분류 후 해당 임시 파일을 삭제하는 로직을 구현했다. 이로서 유저들의 데이터를 여 타 메모리에 저장하는 문제를 방지하고, 유저들의 개인적인 사진을 보유함으로서 생기는 다양한 리스크들을 피할 수 있었다. 

 

미리보기 기능 구현

유저가 사진을 올리면, 어떤 사진을 올렸는지 바로 화면에 띄워주면 사용성이 증대될 것이라고 생각했다. 아래는 JavaScript를 이용한 이미지 미리보기 기능이다. 파일을 업로드하면 해당 함수가 실행되게 구현했다.

 

function previewImage(input) {
    if (input.files && input.files[0]) {
        let input_image_box = document.getElementById('preview');
        let upload_btn = document.getElementById('upload_btn');
        input_image_box.style.display = 'block'
        upload_btn.style.display = 'block'

        var reader = new FileReader();
        reader.onload = function(e) {
            input_image_box.setAttribute('src', e.target.result);
        }
        reader.readAsDataURL(input.files[0]);
    }
}

 

 

보안점

아직 서버를 구현하지 못해 로컬 환경에서만 작동하는 프로젝트이다. 추후 GCP나 AWS 서버를 구매하여 해당 웹 서비스를 배포 및 관리해 보고 싶다. 

 

전체 코드

https://github.com/haminse/AI_binary_classification_web

 

GitHub - haminse/AI_binary_classification_web: AI_binary_classification_web_service_production

AI_binary_classification_web_service_production. Contribute to haminse/AI_binary_classification_web development by creating an account on GitHub.

github.com

 

댓글