エンジニア

なんくるないさ

「このブログはアフィリエイト広告を利用しています」

NNCを使ってリアルタイムに推論する

リアルタイムで推論するコードのイメージは以下 ただし注意点がいくつか

  • pキーを押したら画像を保存してその画像の名前が書かれたcsvを用いて推論するようになっている(リアルタイムにするためにはif文を消す)
  • 使うネットワークとパラメータの書き換えが必要

ただ推論するだけなら毎回画像を保存している処理はいらないですね
そのまま画像を渡すようにすれば良いはずです

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
from nnabla.utils.data_iterator import data_iterator_csv_dataset 
import os 
import cv2
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# 吐き出したネットワーク
def network(x, y, test=False):
    # Input:x -> 3,128,128
    # Convolution -> 16,126,126
    h = PF.convolution(x, 16, (3,3), (0,0), name='Convolution')
    # MaxPooling -> 16,63,63
    h = F.max_pooling(h, (2,2), (2,2))
    # Tanh
    h = F.tanh(h)
    # Convolution_2 -> 8,59,62
    h = PF.convolution(h, 8, (5,2), (0,0), name='Convolution_2')
    # MaxPooling_2 -> 8,29,31
    h = F.max_pooling(h, (2,2), (2,2))
    # Tanh_2
    h = F.tanh(h)
    # Affine -> 10
    h = PF.affine(h, (10,), name='Affine')
    # Tanh_3
    h = F.tanh(h)
    # Affine_2 -> 3
    h = PF.affine(h, (3,), name='Affine_2')
    # Softmax
    h = F.softmax(h)
    # CategoricalCrossEntropy -> 1
    #h = F.categorical_cross_entropy(h, y)
    return h

cap = cv2.VideoCapture(0) # 任意のカメラ番号に変更する

new_dir_path = "./realtime/"
os.makedirs(new_dir_path, exist_ok=True)


#カメラスタート
while True:
    ret, frame = cap.read()
    cv2.imshow("camera", frame)

    k = cv2.waitKey(1)&0xff # キー入力を待つ
    if k == ord('p'):
 # 「p」キーで画像を保存
        date = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = new_dir_path + date +".png"
        cv2.imwrite(path, frame)       
        image_gs = cv2.imread(path)

        path = new_dir_path + date +".png"
        dst = cv2.resize(image_gs,(128,128))
        cv2.imwrite(path, dst)

        f = pd.DataFrame(columns=["x:data","y:data"])
        xdata = path
        ydata = 0
        new_name = pd.Series([xdata,ydata],index=f.columns)
        f = f.append(new_name, ignore_index=True)
        f.to_csv('valu.csv',index=False,header = True )

        test_data = data_iterator_csv_dataset("./valu.csv",1,shuffle=False,normalize=True) 


        #ネットワークの構築
        nn.clear_parameters()
        x = nn.Variable((1,3,128,128))
        t = nn.Variable((1,1))
        y = network(x, t)

        nn.load_parameters('./results.nnp')
        print("load model")


            
        path = new_dir_path + "test" +".png"
        cv2.imwrite(path, frame)       
        image_gs = cv2.imread(path)

        path = new_dir_path + date +".png"
        dst = cv2.resize(image_gs,(128,128))
        cv2.imwrite(path, dst)

        f = pd.DataFrame(columns=["x:data","y:data"])
        xdata = path
        ydata = 0
        new_name = pd.Series([xdata,ydata],index=f.columns)
        f = f.append(new_name, ignore_index=True)
        f.to_csv('valu.csv',index=False,header = True )

        test_data = data_iterator_csv_dataset("./valu.csv",1,shuffle=False,normalize=True) 


        #ネットワークの構築
        nn.clear_parameters()
        x = nn.Variable((1,3,128,128))
        t = nn.Variable((1,1))
        y = network(x, t)

        nn.load_parameters('./results.nnp')
        print("load model")
           for i in range(test_data.size):
            x.d, t.d = test_data.next()
            y.forward()
            print(y.d[0])   
                   
            
        elif k == ord('q'):
            # 「q」キーが押されたら終了する
            break

# キャプチャをリリースして、ウィンドウをすべて閉じる
cap.release()
cv2.destroyAllWindows()