ロジスティック回帰 Androidで動作

TensorFlow

Androidソースコード

package fabo.io.hellotensorflow;

import android.content.res.AssetManager;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class MainActivity extends AppCompatActivity {

    private static final String TAG = "TF_LOG";

    static {
        // Native library that backs TensorFlowInferenceInterface.
        System.loadLibrary("tensorflow_inference");
    }

    /**
     * Loads {@code graph-virus.pb} from the APK assets, feeds the single sample (2.0, 2.0) to the
     * {@code "input"} node, and logs the outputs of {@code "add_op"} and {@code "predict_op200"}.
     */
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        TensorFlowInferenceInterface tensorFlow = new TensorFlowInferenceInterface();
        AssetManager assetManager = getAssets();
        int result = tensorFlow.initializeTensorFlow(assetManager, "file:///android_asset/graph-virus.pb");

        tensorFlow.enableStatLogging(true);
        Log.i(TAG, "---------");
        Log.i(TAG, "initializeTensorFlow:result:" + result);

        float[] x_value = {2.0f, 2.0f};
        // Shape is {rows, cols}: one sample with two features. A shape of {0, 2} would describe a
        // zero-element tensor (0 * 2 = 0), so no data would reach the graph at all.
        int[] inputShape = {1, x_value.length};

        // Add
        tensorFlow.fillNodeFloat("input", inputShape, x_value);
        tensorFlow.runInference(new String[] {"add_op"});
        float[] addResult = new float[x_value.length];
        tensorFlow.readNodeFloat("add_op", addResult);
        for (float value : addResult) {
            Log.i(TAG, "result_add:  " + value);
        }

        // Predict
        // Feed the input again before this run; this should be harmless if the interface already
        // retains the previous feed between runs.
        tensorFlow.fillNodeFloat("input", inputShape, x_value);
        tensorFlow.runInference(new String[] {"predict_op200"});
        // NOTE(review): assumes predict_op200 yields one value per input row — confirm against the graph.
        float[] predictResult = new float[inputShape[0]];
        tensorFlow.readNodeFloat("predict_op200", predictResult);
        for (float value : predictResult) {
            Log.i(TAG, "result_predict:  " + value);
        }
    }
}

predict_op200、add_op ともに処理は実行されるが、出力値が変わらない。（原因と思われる点: 入力形状を {0,2} と指定しているため、要素数 0 のテンソルが渡されており、実際には値が入力されていない。）

package fabo.io.hellotensorflow;

import android.content.res.AssetManager;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class MainActivity extends AppCompatActivity {


    private final static String TAG = "TF_LOG";

    static {
        System.loadLibrary("tensorflow_inference");
    }

    @Override
    protected void onCreate (Bundle savedInstanceState){
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        TensorFlowInferenceInterface mTensorFlowIF = new TensorFlowInferenceInterface();
        AssetManager mAssetManager = getAssets();
        int result = mTensorFlowIF.initializeTensorFlow(mAssetManager, "file:///android_asset/graph-virus.pb");

        mTensorFlowIF.enableStatLogging(true);
        Log.i(TAG, "---------");
        Log.i(TAG, "initializeTensorFlow:result:" + result);

        float[] x_value = new float[2];
        x_value[0] = (float) 9.0;
        x_value[1] = (float) 3.0;

        mTensorFlowIF.fillNodeFloat("input",new int[] {2}, x_value);

        // Add
        mTensorFlowIF.runInference(new String[] {"add_op"});
        float[] result_value1 = new float[2];
        mTensorFlowIF.readNodeFloat("add_op", result_value1);
        Log.i(TAG, "result_add:  " + result_value1[0]);
        Log.i(TAG, "result_add:  " + result_value1[1]);

        // Predict
        int x_cols=2;
        int x_rows=10;
        float x_value2[] = {
                /* x_rows[0] */ -2,-2,
                /* x_rows[1] */ 0,0,
                /* x_rows[2] */ -2,2,
                /* x_rows[3] */ -2 /* error? */
                /* x_rows[4] */ /* error? */
                /* x_rows[...] */ /* error? */
        };
        mTensorFlowIF.fillNodeFloat("input",new int[] {x_rows,x_cols}, x_value2);
        int runInference = mTensorFlowIF.runInference(new String[]{"predict_op200"});

        float[] result_value2 = new float[x_rows];
        mTensorFlowIF.readNodeFloat("predict_op200", result_value2);

        // 入力値->出力を整形
        String x[]= new String[x_rows];
        String message = "";
            for (int row = 0; row < x_rows; row++) {
                message = "x(";

                for (int col = 0; col < x_cols; col++) {
                    x[col] = x_value2.length > (row * x_cols + col) ? Float.toString(x_value2[row * x_cols + col]) : "?";
                    if (col > 0) {
                        message += ",";
                    }
                    message += x[col];
                }
                message += ") -> result_predict:  " + result_value2[row];
                Log.i(TAG, message);
            }
        }
    }
}

出力

I/TF_LOG: result_add:  18.0
I/TF_LOG: result_add:  6.0

I/TF_LOG: x(-2.0,-2.0) -> result_predict:  1.0
I/TF_LOG: x(0.0,0.0) -> result_predict:  0.0
I/TF_LOG: x(-2.0,2.0) -> result_predict:  1.0
I/TF_LOG: x(-2.0,?) -> result_predict:  0.0
I/TF_LOG: x(?,?) -> result_predict:  0.0
I/TF_LOG: x(?,?) -> result_predict:  0.0
I/TF_LOG: x(?,?) -> result_predict:  0.0
I/TF_LOG: x(?,?) -> result_predict:  0.0
I/TF_LOG: x(?,?) -> result_predict:  0.0
I/TF_LOG: x(?,?) -> result_predict:  0.0