이전에는 FloydHub를 이용하여 모델을 학습하였다.
이제 학습된 모델을 TensorFlow Lite형식의 모델로 변환 후 Android에 올려 Image Classfication을 해보려고 한다.
https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/index.html#0
구글 코드랩에서 tflite로 변환하는 부분을 참고 하였고 구글링을 통해서 필요한 변환 과정, 스크립트, 도구 정보를 습득 하였습니다.
이 게시글에서 필요한 파일들은 모두 제 github에 올려놓았습니다.
현재 글과 해당 github에 있는 모델 구조가 상이하므로 참고해주세요.
Link : https://github.com/GyuminDev/CNN_Classification_N
필요한 파일 및 환경
전체 과정은 크게 보면 이렇다.
1. TensorFlow를 이용해 학습된 모델 파일 -> .pb
- .pb파일을 tensorboard를 이용해 시각화 해보며 Graph에서 input node, ouput node의 이름을 파악해야함.
2. .pb파일을 optimize 한다.
3. optimize한 .pb파일도 역시 tensorboard를 이용해서 graph의 변화가 없는지 check
4. 마지막으로 .pb파일을 .tflite형태로 변환한다.
5. .tflite파일을 Android Project의 올려서 Run!!
1. .pb파일 확인 및 Tensorboard를 이용하여 시각화
학습을 진행한 후 output을 보면 이렇다
1. step별 checkpoint파일
2. 모델 파일 - cifar.pb
3. tensorboard event파일 - train, validate 폴더로 구성되어 있음
먼저 Script/pb_tensorboard.py를 이용해서 학습된 모델인 cifar.pb를 시각화 해보자.
pb_tensorboard.py가 있는 위치로 이동 후 아래처럼 스크립트 실행을 한다.
model_dir=pb파일이 있는 경로
log_dir=tensorboard 이벤트파일을 저장할 경로
1 2 3 | python pb_tensorboard.py \ model_dir=./cifar10.pb \ log_dir=./pb_logs | cs |
이렇게 작성하면 pb_logs 하위에 tensorboard Event파일이 생성된다.
다음과 같이 cifar10.pb를 시각화하여 그래프의 형태를 볼 수 있다.
여기서 input과 output 노드의 이름을 메모해 놓아야 한다.
input node = input_node -> shape = [?, 32, 32, 3]
output = output -> shape = [?, 10]
input은 32x32 사이즈의 RGB(3채널)이미지를 받으며 Output으로는 softmax 함수를 취하여 10가지 Label(클래스)에 대한 예측점수가 나오게 된다.
중요한 점은 TensorFlow Code를 작성 할 때 Tensorboard 시각화가 잘되도록 scope 부분과 name을 잘 지정해 주어야 한다.
2. .pb파일을 optimize
1 2 3 4 5 | python optimize_for_inference.py \ --input=./cifar.pb \ --output=./opt_cifar10.pb \ --input_names="input_node" \ --output_names="output" | cs |
3. .pb convert to .tflite
1 2 3 4 5 6 7 8 9 10 11 | IMAGE_SIZE=32 toco \ --input_file=./opt_cifar10.pb \ --output_file=./cifar10.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --input_shape=1,${IMAGE_SIZE},${IMAGE_SIZE},3 \ --input_array=input_node \ --output_array=output \ --inference_type=FLOAT \ --input_type=FLOAT | cs |
4. example에 내가 만든 cifar10.tflite로 대체하기
'ML_DL' 카테고리의 다른 글
FloydHub 이용 TensorFlow CNN CIFAR10 예제 학습 (1) | 2018.04.09 |
---|