AI, Backend, Mobile

AI w aplikacji, czyli TensorFlow na Androidzie

Obecnie prawie każdy producent elektroniki twierdzi, że w jego sprzęcie mieszka potężny dżin zwany sztuczną inteligencją. Po smartfonach przychodzi pomału kolej na sprzęty AGD, takie jak pralka czy mikrofalówka. Choć oczywiście “AI” to marketingowy buzzword, bo póki co nie pogadamy z naszą lodówką zamawiając na wieczór schłodzone Martini (wstrząśnięte, nie zmieszane), to myślę, że warto przyjrzeć się bliżej technologiom pozwalającym na użycie sieci neuronowych na urządzeniach mobilnych.


Krzysztof Joachimiak. Machine Learning Engineer i Software Developer w StethoMe®. Zaczynał jako Android Developer, po czym przeskoczył do swojej ulubionej działki IT, czyli uczenia maszynowego. Stara się trzymać rękę na pulsie jeśli chodzi o rozwój tak zwanej – nieraz trochę na wyrost — sztucznej inteligencji. Ostatnio związany z poznańskim startupem z branży medycznej, StethoMe®.


Zdalnie czy lokalnie?

Jeśli piszemy aplikację na system Android, która ma korzystać z dobrodziejstw sztucznej inteligencji, stajemy przed wyborem — czy nasz model ma być uruchamiany na telefonie/tablecie, czy też ma funkcjonować jako web service odpytywany przez nasze urządzenie mobilne. Głównymi kryteriami przy podejmowaniu decyzji jest wielkość samego modelu, czas na wykonanie potrzebnych obliczeń, oraz to, czy model ma działać offline. W większości przypadków zdecydujemy się zapewne na zdalne API. Co jednak zrobić w przypadku, jeśli chcielibyśmy skorzystać z sieci neuronowej, która ma być uruchamiana bezpośrednio na urządzeniu z Androidem? Proponuję użyć dobrze znaną bibliotekę TensorFlow od Google.

Zanim zaczniemy

  1. Przygotuj odpowiednie środowisko. Osobiście pracowałem na Linuksie (Ubuntu 16.04) korzystając z Pythona 2.7.
  2. Zainstaluj TensorFlow.
  3. Zainstaluj najnowszą wersję narzędzia Bazel.
  4. Sklonuj repozytorium TensorFlow.
  5. Zbuduj narzędzia służące do transformowania i optymalizacji sieci
    1. bazel build sciezka_do_tensorflow/tools/graph_transforms:transform_graph
    2. bazel build sciezka_do_tensorflow/tools/graph_transforms:summarize_graph
    3. bazel build sciezka_do_tensorflow/contrib/lite/toco:toco

Eksportujemy naszą sieć

W tym tutorialu chodzi mi o pokazanie samego procesu eksportowania sieci, nie będę więc powielać treści mnóstwa tutoriali do TensorFlow i celowo pominę proces treningu sieci. Załóżmy, że będzie to banalnie prosta architektura, reprezentująca regresję logistyczną.

Podczas tworzenia modelu musimy pamiętać, żeby nazwać węzeł wejściowy oraz wyjściowy grafu; w przeciwnym wypadku będziemy zmuszeni do jego ręcznego przeglądania i znalezienia nazw nadanych automatycznie przez TensorFlow.

`` python
export_graph.py

import tensorflow as tf
import os

OUTPUT_DIR = "net"
OUTPUT_NAME = "model"
OUTPUT_SIZE = 1 # wielkość wyjścia

if __name__ == '__main__':

    # Definicja sieci
    X = tf.placeholder(tf.float32, (None, 10), name="input")
    network = tf.layers.dense(X, OUTPUT_SIZE)
    network = tf.nn.sigmoid(network)

    # Użycie funkcji liniowej to dobry sposób na nadanie nazwy ostatniemu węzłowi, jeśli 
    # np. wczytujemy skądś sieć. Tworzymy wówczas dodatkowy węzeł 
    network = tf.identity(network, name= "output")

    init = tf.global_variables_initializer()

with tf.Session() as sess: 
    # Inicjalizacja sieci losowymi wagami
     sess.run(init)

    # Zapisujemy graf
 tf.train.write_graph(sess.graph_def, '.',os.path.join(OUTPUT_DIR, "{}.pbtxt".format(OUTPUT_NAME)))
   
# Zapisujemy wagi modelu
saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, OUTPUT_NAME))

Podczas tworzenia modelu w TensorFlow warto zwrócić uwagę na dwie rzeczy:

  • Jeśli wykorzystujemy warstwy konwolucyjne, bezpieczniej jest użyć domyślnej kolejności wymiarów z kanałami na końcu. Tylko ta wersja jest obecnie wspierana przez implementację TensoFlow działającą na CPU.
  • Normalizacja batcha wymaga ustawiania odpowiedniej flagi na etapie predykcji (flaga training).

Po uruchomieniu tego skryptu, otrzymamy folder zawierający pięć plików, w tym <OUTPUT_NAME>.pbtx, zawierający strukturę grafu obliczeniowego, oraz cztery pliki wyprodukowane przez Savera, zawierające wagi modelu.

Następnym krokiem jest połączenie struktury z wagami w jednym pliku. W tym celu możemy użyć narzędzia freeze_graph, które należy zbudować podobnie jak pozostałe narzędzia wymienione przeze mnie w punkcie 5., lub użyć poniższego skryptu:

``python
freeze.py


import sys
from tensorflow.python.tools import freeze_graph
import os

INPUT_DIR = sys.argv[1]
OUTPUT_DIR = sys.argv[2]

MODEL_NAME = sys.argv[3]
OUTPUT_NODE = sys.argv[4]


freeze_graph.freeze_graph(input_graph=os.path.join(INPUT_DIR, "{}.pbtxt".format(MODEL_NAME)),
                          input_saver="",
                          input_binary=False,
                          input_checkpoint=os.path.join(INPUT_DIR, MODEL_NAME),
                          output_node_names=OUTPUT_NODE,
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph=os.path.join(OUTPUT_DIR,"{}.pb".format(MODEL_NAME)),
                          clear_devices=True,
                          initializer_nodes="")


   ``

Jeśli zdecydujemy się na uruchomienie użycia skryptu, powinniśmy zastosować następującą komendę:

python freeze.py model model net output

W tym miejscu będziemy już mieć plik model.pb, który zawiera pełen model sieci — strukturę grafu obliczeniowego wraz z wagami. To jednak nie jest jeszcze ostatni krok. Powinniśmy dodatkowo zoptymalizować nasz graf, aby obliczenia były mniej zasobożerne (co zresztą jest szczególnie ważne na urządzeniu mobilnym!).

Optymalizacja sieci

W tym celu używamy uprzednio zbudowanego narzędzia graph_transform. Jako argumenty musimy przekazać nazwę pliku wejściowego (model.pb) oraz wyjściowego — dla odróżnienia nazwijmy go model_opt.pb. graph_transform oferuje szereg różnego rodzaju transformacji; ja polecam użyć przede wszystkim fold_constants, remove_device (bez tego graf będzie żądał odpalenia na konkretnym rodzaju urządzenia, np. na karcie graficznej, czego nie oferuje biblioteka TesnorFlow na Andoridzie). Pełną listę dostępnych transformacji można znaleźć tutaj. Wśród nich na pewno warto zwrócić uwagę na fold_batch_norm oraz fold_old_batch_norm używaną do optymalizacji obliczeń w warstwie normalizacji batcha.

```bash
input_file=model.pb
output_file=model_opt.pb
/home/some/path/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph 
--in_graph=$input_file 
--out_graph=$output_file 
--inputs='input'    # nazwy węzłów wejściowych
--outputs='output'  # nazwy węzłów wyjściowych
--transforms='
  remove_device
  fold_constants(ignore_errors=true)

```

Jeden skrypt

Wszystkie powyższe operacje możemy szybko powtórzyć, korzystając z poniższego skryptu:

temp=output

# !!!! Ustaw ścieżkę do lokalnego repozytorium TF !!!!
tensorflow_path=/home/user/Desktop/tensorflow

# Input
path=$1
model_name=$2
input_node=$3
output_node=$4

mkdir $temp

python freeze.py $path $temp $model_name $output_node

# https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/#add_default_attributes

input_file=$temp/$model_name.pb
output_file=$temp/$model_name"_opt.pb"

echo $input_file $output_file

$tensorflow_path/bazel-bin/tensorflow/tools/graph_transforms/transform_graph 
--in_graph=$input_file 
--out_graph=$output_file 
--inputs=$input_node 
--outputs=$output_node 
--transforms='
  remove_device
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  strip_unused_nodes
  remove_attribute(attribute_name=_class)'


$tensorflow_path/bazel-bin/tensorflow/tools/graph_transforms/summarize_graph 
--in_graph=$output_file

Na koniec działania, skrypt wyświetli nam podsumowanie dla danej sieci, w tym m.in. liczbę jej parametrów.

TF Mobile czy TF Lite

Przez pewien czas istniały równocześnie dwa standardy: starszy TF Mobile oraz nowszy TF Lite. Świadomie trzymałem się starszego, gdyż nowszy przez długi czas nie posiadał zaimplementowanych wielu operacji, np.: funkcji ELU. Obecnie TF Mobile jest porzucany na rzecz TF Lite i na początku 2019 ten pierwszy standard stanie się przestarzały.

TF Mobile

Otwieramy nasz projekt z aplikacją androidową i umieszczamy nasz plik model_opt.pb w folderze assets.

W pliku build.gradle dodajemy najnowszą wersję biblioteki TensorFlow

implementation 'org.tensorflow:tensorflow-android:+'

Następnie stwórzmy sobie klasę pomocniczą:

```java
package com.stethome.androidtf;
import android.content.Context;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;


public class TensorflowModel {

    // Model
    private static final String MODEL_FILE = "file:///android_asset/model_opt.pb";
    private static final String INPUT_NODE = "input";
    private static final String OUTPUT_NODE = "output";
    private static final long[] INPUT_SIZE = {1, 10};

    // Tensorflow interface
    private TensorFlowInferenceInterface inferenceInterface;

    public TensorflowModel(Context context) {
        this.inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] predict(float[] input) {
        this.inferenceInterface.feed(INPUT_NODE, input, INPUT_SIZE);
        this.inferenceInterface.run(new String[]{OUTPUT_NODE});

        // Result
        float[] result = new float[1]; // wielkość wyjścia
        inferenceInterface.fetch(OUTPUT_NODE, result);
        return result;
    }

Jej użycie w aplikacji będzie wyglądać następująco:

```java
package com.stethome.androidtf;

import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.util.TimingLogger;

import java.util.Arrays;
import java.util.concurrent.TimeUnit;

public class MainActivity extends AppCompatActivity {

    private TensorflowModel tfModel;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        tfModel = new TensorflowModel(this);

        // Dummy data
        float[] input = new float[10];

        // Predykcje
        float[] output = tfModel.predict(input);
        Log.d("OUTPUT", Arrays.toString(output));
        Log.d("OUTPUT SIZE", Integer.toString(output.length));

    }
}

```

TF Lite

W celu skonwertowania naszego modelu do nowego standardu, użyjmy kolejnego narzędzia z repozytorium TF:

```bash
input_file=model_opt.pb
output_file=model.lite
input_node_name=input
output_node_name=output
/home/some/path/tensorflow/bazel-bin/tensorflow/contrib/lite/toco/toco 
--input_file=$input_file 
--input_format=TENSORFLOW_GRAPHDEF 
--output_format=TFLITE 
--output_file=$output_file 
--inference_type=FLOAT 
--input_type=FLOAT 
--input_arrays=$input_node_name 
--output_arrays=$output_node_name 
--input_shapes=1,10 # 
```

Dodajemy plik model.lite do folderu assets.

W pliku build.gradle dodajemy najnowszą wersję biblioteki TensorFlow

implementation ‘org.tensorflow:tensorflow-lite:+’

Ładujemy plik z modelem:

private MappedByteBuffer loadModel(Activity activity,String MODEL_FILE) throws IOException {

AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE);

FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();

long startOffset = fileDescriptor.getStartOffset();
long length = fileDescriptor.getDeclaredLength();

return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, length);
}

Uruchamiamy interpreter TF Lite:

import org.tensorflow.lite.Interpreter;
String MODEL="model.lite";
Interpreter tflite;
try {
   tflite=new Interpreter(loadModelFile(MainActivity.this,modelFile));
} catch (IOException e) {
   e.printStackTrace();
}

A następnie wykonujemy predykcję:

float[][] input=new float[1][10];
float[][] output=new float[][]{{0}};
tflite.run(inp,out);

Podsumowanie

Jak widać, wykorzystanie sieci neuronowej do uruchamiania bezpośrednio na urządzeniu z Androidem nie jest skomplikowane. Myślę, że TensorFlow jest najwygodniejszym do tego narzędziem.

baner

Podobne artykuły

[wpdevart_facebook_comment curent_url="https://geek.justjoin.it/ai-aplikacji-czyli-tensorflow-androidzie/" order_type="social" width="100%" count_of_comments="8" ]