モデルデータを読込 その2

Androidアプリでモデルデータを読み込む
今回はアヤメの判定を行う学習モデルを作成し、そのモデルデータを読み込んでテストを行う
データセットについては以下から入手する

$ curl -O https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

学習モデルのプログラム

# coding:utf-8
# tensorflow version1.0.0
import numpy as np
import tensorflow as tf

### データの準備
# データセットの読み込み
dataset = np.genfromtxt("./bezdekIris.data", delimiter=',', dtype=[float, float, float, float, "S32"])
# データセットの順序をランダムに並べ替える
np.random.shuffle(dataset)

def get_labels(dataset):
    """ラベル(正解データ)を1ofKベクトルに変換する"""
    raw_labels = [item[4] for item in dataset]
    labels = []
    for l in raw_labels:
        if l == "Iris-setosa":
            labels.append([1.0,0.0,0.0])
        elif l == "Iris-versicolor":
            labels.append([0.0,1.0,0.0])
        elif l == "Iris-virginica":
            labels.append([0.0,0.0,1.0])
    return np.array(labels)

def get_data(dataset):
    """データセットをnparrayに変換する"""
    raw_data = [list(item)[:4] for item in dataset]
    return np.array(raw_data)

# ラベル
labels = get_labels(dataset)
# データ
data = get_data(dataset)
# 訓練データとテストデータに分割する
# 訓練用データ
train_labels = labels[:120]
train_data = data[:120]
# テスト用データ
test_labels = labels[120:]
test_data = data[120:]


### モデルをTensor形式で実装

# ラベルを格納するPlaceholder
t = tf.placeholder(tf.float32, shape=(None,3))
# 入力データを格納するPlaceholder nameをつけておく
X = tf.placeholder(tf.float32, shape=(None,4),name="input")

# 隠れ層のノード数
node_num = 1024
w_hidden = tf.Variable(tf.truncated_normal([4,node_num]))
b_hidden = tf.Variable(tf.zeros([node_num]))
f_hidden = tf.matmul(X, w_hidden) + b_hidden
hidden_layer = tf.nn.relu(f_hidden)

# 出力層
w_output = tf.Variable(tf.zeros([node_num,3]))
b_output = tf.Variable(tf.zeros([3]))
f_output = tf.matmul(hidden_layer, w_output) + b_output
#出力結果の値が入る nameをつけておく
p = tf.nn.softmax(f_output,name="output")

# 交差エントロピー
cross_entropy = t * tf.log(p)
# 誤差関数
loss = -tf.reduce_mean(cross_entropy)
# トレーニングアルゴリズム
# 勾配降下法 学習率0.001
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train_step = optimizer.minimize(loss)
# モデルの予測と正解が一致しているか調べる
correct_pred = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
# モデルの精度
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

saver = tf.train.Saver()
run_metadata = tf.RunMetadata()
### 学習の実行
with tf.Session() as sess:
    #ディレクトリは事前に作成しておく
    ckpt = tf.train.get_checkpoint_state('./ckpt-iris')
    if ckpt:
      # checkpointファイルから最後に保存したモデルへのパスを取得する
      last_model = ckpt.model_checkpoint_path
      print("load {0}".format(last_model))
      # 学習済みモデルを読み込む
      saver.restore(sess, last_model)
    else:
      print("initialization")
    # ログの設定
    tf.summary.histogram("Hidden_layer_wights", w_hidden)
    tf.summary.histogram("Hidden_layer_biases", b_hidden)
    tf.summary.histogram("Output_layer_wights", w_output)
    tf.summary.histogram("Output_layer_wights", b_output)
    tf.summary.scalar("Accuracy", accuracy)
    tf.summary.scalar("Loss", loss)
    summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter("./iris_cassification_log", sess.graph)
    #初期化
    sess.run(tf.global_variables_initializer())

    i = 0
    for _ in range(5000):
        i += 1
        # トレーニング
        sess.run(train_step, feed_dict={X:train_data,
                                        t:train_labels})
        # 200ステップごとに精度を出力
        if i % 200 == 0:
            # コストと精度を出力
            train_summary,train_loss, train_acc = sess.run([summary,loss,accuracy], feed_dict={X:train_data,t:train_labels})
            writer.add_summary(train_summary,i)
            print "Step: %d" % i
            print "[Train] cost: %f, acc: %f" % (train_loss, train_acc)

    saver.save(sess, "iris-model")
sess.close()

モデルデータ(pbファイル)を作成する

# coding:utf-8
# tensorflow version1.0.0
import tensorflow as tf
from tensorflow.python.framework import graph_util

#モデルデータを作成する
def freeze_graph(model_folder):
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    output_graph = model_folder + "/frozen_model.pb"
    print(output_graph)

    output_node_names = "output,input"
    #学習時に計算にcpuやgpuの指定を行なっていた時、読み込む側でその指定に依存しないようにする
    clear_devices = True
    #グラフをインポートする
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
      #保存されている重みやバイアスの変数を復元する
        saver.restore(sess, input_checkpoint)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(",")
        )
        #モデルデータを書き出す
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

モデルデータをAndroidで読み込みテストする

pbファイルの置き場所についてはモデルデータを読込の記事を参照

import android.content.DialogInterface;
import android.content.res.AssetManager;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;
//導入については開発環境構築(Android)を参照
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class MainActivity extends AppCompatActivity {
    //判定結果を表示するテキストビュー
    private TextView ansView;
    //入力データを入れるテキストフォーム
    private EditText editIrisFeature1;
    private EditText editIrisFeature2;
    private EditText editIrisFeature3;
    private EditText editIrisFeature4;
    //判定を開始するボタン
    private Button detectButton;
    static {
        System.loadLibrary("tensorflow_inference");
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        ansView = (TextView) findViewById(R.id.answer_text);

        editIrisFeature1 = (EditText) findViewById(R.id.edit_Iris_feature1);
        editIrisFeature2 = (EditText) findViewById(R.id.edit_Iris_feature2);
        editIrisFeature3 = (EditText) findViewById(R.id.edit_Iris_feature3);
        editIrisFeature4 = (EditText) findViewById(R.id.edit_Iris_feature4);

        detectButton = (Button) findViewById(R.id.detect_Button);
        //ボタンを押した時のイベント
        detectButton.setOnClickListener(new View.OnClickListener(){
            @Override
            public void onClick(View v) {
               //入力データを格納する配列で判定に使う
                float[] features = new float[4];
                //入力のテキストフォームから得る値を取得する
                String value1 = null;
                String value2 = null;
                String value3 = null;
                String value4 = null;
                Boolean flag = false;
                try{
                    value1 = editIrisFeature1.getText().toString();
                    value2 = editIrisFeature2.getText().toString();
                    value3 = editIrisFeature3.getText().toString();
                    value4 = editIrisFeature4.getText().toString();
                    features[0] = Float.valueOf(value1);
                    features[1] = Float.valueOf(value2);
                    features[2] = Float.valueOf(value3);
                    features[3] = Float.valueOf(value4);
                }catch (java.lang.NumberFormatException e){
                  //未入力のまま開始した時エラーダイアログを表示するようにする
                    flag = true;
                    ansView.setText("");
                    AlertDialog.Builder alertDialog = new AlertDialog.Builder(MainActivity.this);

                    alertDialog.setTitle("input error");
                    alertDialog.setMessage("Please input values");
                    alertDialog.setPositiveButton("OK", new DialogInterface.OnClickListener() {
                        public void onClick(DialogInterface dialog,int which) {
                        }
                    });

                    alertDialog.create();
                    alertDialog.show();
                }

                if(!flag){
                  //判定を開始する
                    onDetectClicked(features);
                }

            }
        });


    }
    private void onDetectClicked(float[] f) {
        TensorFlowInferenceInterface mTensorFlowIF = new TensorFlowInferenceInterface();
        AssetManager mAssetManager = getAssets();
        //モデルデータ読み込み
        int result = mTensorFlowIF.initializeTensorFlow(mAssetManager, "file:///android_asset/frozen_model.pb");

        ansView.setText("");
        //入力データを入れる
        //第一引数にはモデル作成時に指定した入力データを格納するPlaceholderのnameを指定する
        //第二引数にはshapeを指定する 今回は一つのサンプルデータを渡してテストするので(1,4)
        //第三引数は入力データ
        mTensorFlowIF.fillNodeFloat("input:0",new int[] {1,4},f);
        //判定結果を格納する配列
        float[] result_value = new float[3];
        //判定を行うこの時にモデル作成時の出力のnameを指定する
        mTensorFlowIF.runInference(new String[] {"output:0"});
        //result_valueに結果を代入する
        mTensorFlowIF.readNodeFloat("output:0", result_value);

        //result_valueには[a,b,c]と値が入っており、一番大きな値が入ってる配列のインデックスが入力データのクラスとなる
        int ansIndex = getAnswer(result_value);
        switch (ansIndex){
            case 0:
                ansView.setText("Detected : Iris-Setosa");
                break;

            case 1:
                ansView.setText("Detected : Iris-versicolor");
                break;

            case 2:
                ansView.setText("Detected : Iris-virginica");
                break;
        }
    }
    //一番大きな値が入っているインデックスを返す
    private int getAnswer(float[] f){
        int argmax = 0;
        float max = f[0];
        for(int i=0;i<f.length;i++){
            if(max < f[i]){
                max = f[i];
                argmax = i;
            }
        }

        return argmax;
    }
}

activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/activity_main"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:paddingBottom="@dimen/activity_vertical_margin"
    android:paddingLeft="@dimen/activity_horizontal_margin"
    android:paddingRight="@dimen/activity_horizontal_margin"
    android:paddingTop="@dimen/activity_vertical_margin"
    tools:context="各自のproject名">

    <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
        android:orientation="horizontal"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:paddingTop="16dp">
        <EditText
            android:id="@+id/edit_Iris_feature1"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
        <EditText
            android:id="@+id/edit_Iris_feature2"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
        <EditText
            android:id="@+id/edit_Iris_feature3"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
        <EditText
            android:id="@+id/edit_Iris_feature4"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
    </LinearLayout>

    <Button
        android:id="@+id/detect_Button"
        android:text="Button"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_centerHorizontal="true"
        android:layout_marginTop="70dp" />

    <TextView
        android:id="@+id/answer_text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_below="@+id/detect_Button"
        android:textSize="30sp"
        android:layout_margin="30dp"
        android:gravity="center" />

</RelativeLayout>

build.gradle androidの所に下記を追加しておく

sourceSets {
    main {
        jniLibs.srcDirs = ['libs']
        assets.srcDirs = ['assets']
    }
}


データセットからサンプルを選び、正解ラベルが返ってきたら成功