AI w aplikacji, czyli TensorFlow na Androidzie

Obecnie prawie każdy producent elektroniki twierdzi, że w jego sprzęcie mieszka potężny dżin zwany sztuczną inteligencją. Po smartfonach przychodzi pomału kolej na sprzęty AGD, takie jak pralka czy mikrofalówka. Choć oczywiście “AI” to marketingowy buzzword, bo póki co nie pogadamy z naszą lodówką zamawiając na wieczór schłodzone Martini (wstrząśnięte, nie zmieszane), to myślę, że warto przyjrzeć się bliżej technologiom pozwalającym na użycie sieci neuronowych na urządzeniach mobilnych.


Krzysztof Joachimiak. Machine Learning Engineer i Software Developer w StethoMe®. Zaczynał jako Android Developer, po czym przeskoczył do swojej ulubionej działki IT, czyli uczenia maszynowego. Stara się trzymać rękę na pulsie jeśli chodzi o rozwój tak zwanej – nieraz trochę na wyrost — sztucznej inteligencji. Ostatnio związany z poznańskim startupem z branży medycznej, StethoMe®.


Zdalnie czy lokalnie?

Jeśli piszemy aplikację na system Android, która ma korzystać z dobrodziejstw sztucznej inteligencji, stajemy przed wyborem — czy nasz model ma być uruchamiany na telefonie/tablecie, czy też ma funkcjonować jako web service odpytywany przez nasze urządzenie mobilne. Głównymi kryteriami przy podejmowaniu decyzji jest wielkość samego modelu, czas na wykonanie potrzebnych obliczeń, oraz to, czy model ma działać offline. W większości przypadków zdecydujemy się zapewne na zdalne API. Co jednak zrobić w przypadku, jeśli chcielibyśmy skorzystać z sieci neuronowej, która ma być uruchamiana bezpośrednio na urządzeniu z Androidem? Proponuję użyć dobrze znaną bibliotekę TensorFlow od Google.

Zanim zaczniemy

  1. Przygotuj odpowiednie środowisko. Osobiście pracowałem na Linuksie (Ubuntu 16.04) korzystając z Pythona 2.7.
  2. Zainstaluj TensorFlow.
  3. Zainstaluj najnowszą wersję narzędzia Bazel.
  4. Sklonuj repozytorium TensorFlow.
  5. Zbuduj narzędzia służące do transformowania i optymalizacji sieci
    1. bazel build sciezka_do_tensorflow/tools/graph_transforms:transform_graph
    2. bazel build sciezka_do_tensorflow/tools/graph_transforms:summarize_graph
    3. bazel build sciezka_do_tensorflow/contrib/lite/toco:toco

Eksportujemy naszą sieć

W tym tutorialu chodzi mi o pokazanie samego procesu eksportowania sieci, nie będę więc powielać treści mnóstwa tutoriali do TensorFlow i celowo pominę proces treningu sieci. Załóżmy, że będzie to banalnie prosta architektura, reprezentująca regresję logistyczną.

Podczas tworzenia modelu musimy pamiętać, żeby nazwać węzeł wejściowy oraz wyjściowy grafu; w przeciwnym wypadku będziemy zmuszeni do jego ręcznego przeglądania i znalezienia nazw nadanych automatycznie przez TensorFlow.

Podczas tworzenia modelu w TensorFlow warto zwrócić uwagę na dwie rzeczy:

  • Jeśli wykorzystujemy warstwy konwolucyjne, bezpieczniej jest użyć domyślnej kolejności wymiarów z kanałami na końcu. Tylko ta wersja jest obecnie wspierana przez implementację TensoFlow działającą na CPU.
  • Normalizacja batcha wymaga ustawiania odpowiedniej flagi na etapie predykcji (flaga training).

Po uruchomieniu tego skryptu, otrzymamy folder zawierający pięć plików, w tym <OUTPUT_NAME>.pbtx, zawierający strukturę grafu obliczeniowego, oraz cztery pliki wyprodukowane przez Savera, zawierające wagi modelu.

Następnym krokiem jest połączenie struktury z wagami w jednym pliku. W tym celu możemy użyć narzędzia freeze_graph, które należy zbudować podobnie jak pozostałe narzędzia wymienione przeze mnie w punkcie 5., lub użyć poniższego skryptu:

Jeśli zdecydujemy się na uruchomienie użycia skryptu, powinniśmy zastosować następującą komendę:

python freeze.py model model net output

W tym miejscu będziemy już mieć plik model.pb, który zawiera pełen model sieci — strukturę grafu obliczeniowego wraz z wagami. To jednak nie jest jeszcze ostatni krok. Powinniśmy dodatkowo zoptymalizować nasz graf, aby obliczenia były mniej zasobożerne (co zresztą jest szczególnie ważne na urządzeniu mobilnym!).

Optymalizacja sieci

W tym celu używamy uprzednio zbudowanego narzędzia graph_transform. Jako argumenty musimy przekazać nazwę pliku wejściowego (model.pb) oraz wyjściowego — dla odróżnienia nazwijmy go model_opt.pb. graph_transform oferuje szereg różnego rodzaju transformacji; ja polecam użyć przede wszystkim fold_constants, remove_device (bez tego graf będzie żądał odpalenia na konkretnym rodzaju urządzenia, np. na karcie graficznej, czego nie oferuje biblioteka TesnorFlow na Andoridzie). Pełną listę dostępnych transformacji można znaleźć tutaj. Wśród nich na pewno warto zwrócić uwagę na fold_batch_norm oraz fold_old_batch_norm używaną do optymalizacji obliczeń w warstwie normalizacji batcha.

Jeden skrypt

Wszystkie powyższe operacje możemy szybko powtórzyć, korzystając z poniższego skryptu:

Na koniec działania, skrypt wyświetli nam podsumowanie dla danej sieci, w tym m.in. liczbę jej parametrów.

TF Mobile czy TF Lite

Przez pewien czas istniały równocześnie dwa standardy: starszy TF Mobile oraz nowszy TF Lite. Świadomie trzymałem się starszego, gdyż nowszy przez długi czas nie posiadał zaimplementowanych wielu operacji, np.: funkcji ELU. Obecnie TF Mobile jest porzucany na rzecz TF Lite i na początku 2019 ten pierwszy standard stanie się przestarzały.

TF Mobile

Otwieramy nasz projekt z aplikacją androidową i umieszczamy nasz plik model_opt.pb w folderze assets.

W pliku build.gradle dodajemy najnowszą wersję biblioteki TensorFlow

implementation 'org.tensorflow:tensorflow-android:+'

Następnie stwórzmy sobie klasę pomocniczą:

Jej użycie w aplikacji będzie wyglądać następująco:

TF Lite

W celu skonwertowania naszego modelu do nowego standardu, użyjmy kolejnego narzędzia z repozytorium TF:

Dodajemy plik model.lite do folderu assets.

W pliku build.gradle dodajemy najnowszą wersję biblioteki TensorFlow

implementation ‘org.tensorflow:tensorflow-lite:+’

Ładujemy plik z modelem:

Uruchamiamy interpreter TF Lite:

A następnie wykonujemy predykcję:

Podsumowanie

Jak widać, wykorzystanie sieci neuronowej do uruchamiania bezpośrednio na urządzeniu z Androidem nie jest skomplikowane. Myślę, że TensorFlow jest najwygodniejszym do tego narzędziem.

Patronujemy

 
 
Polecamy
HRakterna środa — sztuczna inteligencja na LinkedInie