diff --git a/CMakeLists.txt b/CMakeLists.txt index ceeaf5075af..3fb267b566f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -495,7 +495,7 @@ if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0.0 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-stringop-overflow") endif() -if(ANDROID) +if(ANDROID AND (NOT ANDROID_DEBUG_SYMBOLS)) if(CMAKE_COMPILER_IS_GNUCXX) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s") else() diff --git a/android/.gitignore b/android/.gitignore index 9d8383a7d42..921f163c3ab 100644 --- a/android/.gitignore +++ b/android/.gitignore @@ -6,11 +6,5 @@ gradle/wrapper .idea/* .externalNativeBuild build -pytorch_android/src/main/cpp/libtorch_include/x86/** -pytorch_android/src/main/cpp/libtorch_include/x86_64/** -pytorch_android/src/main/cpp/libtorch_include/armeabi-v7a/** -pytorch_android/src/main/cpp/libtorch_include/arm64-v8a/** -pytorch_android/src/main/jniLibs/x86/** -pytorch_android/src/main/jniLibs/x86_64/** -pytorch_android/src/main/jniLibs/armeabi-v7a/** -pytorch_android/src/main/jniLibs/arm64-v8a/** +pytorch_android/src/main/cpp/libtorch_include/** +pytorch_android/src/main/jniLibs/** diff --git a/android/build_test_app.sh b/android/build_test_app.sh new file mode 100755 index 00000000000..489f08072ad --- /dev/null +++ b/android/build_test_app.sh @@ -0,0 +1,100 @@ +#!/bin/bash +set -eux + +PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" + +PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android +WORK_DIR=$PYTORCH_DIR + +echo "PYTORCH_DIR:$PYTORCH_DIR" +echo "WORK_DIR:$WORK_DIR" + +echo "ANDROID_HOME:$ANDROID_HOME" +if [ ! -z "$ANDROID_HOME" ]; then + echo "ANDROID_HOME not set; please set it to Android sdk directory" +fi + +if [ ! -d $ANDROID_HOME ]; then + echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?" + exit 1 +fi + +GRADLE_PATH=gradle +GRADLE_NOT_FOUND_MSG="Unable to find gradle, please add it to PATH or set GRADLE_HOME" + +if [ ! -x "$(command -v gradle)" ]; then + if [ -z "$GRADLE_HOME" ]; then + echo GRADLE_NOT_FOUND_MSG + exit 1 + fi + GRADLE_PATH=$GRADLE_HOME/bin/gradle + if [ ! -f "$GRADLE_PATH" ]; then + echo GRADLE_NOT_FOUND_MSG + exit 1 + fi +fi +echo "GRADLE_PATH:$GRADLE_PATH" + +ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64" +CUSTOM_ABIS_LIST=false +if [ $# -gt 0 ]; then + ABIS_LIST=$1 + CUSTOM_ABIS_LIST=true +fi + +echo "ABIS_LIST:$ABIS_LIST" + +LIB_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs +INCLUDE_DIR=$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include +mkdir -p $LIB_DIR +rm $LIB_DIR/* +mkdir -p $INCLUDE_DIR + +for abi in $(echo $ABIS_LIST | tr ',' '\n') +do +echo "abi:$abi" + +OUT_DIR=$WORK_DIR/build_android_$abi + +rm -rf $OUT_DIR +mkdir -p $OUT_DIR + +pushd $PYTORCH_DIR +python $PYTORCH_DIR/setup.py clean + +ANDROID_ABI=$abi BUILD_PYTORCH_MOBILE=1 VERBOSE=1 ANDROID_DEBUG_SYMBOLS=1 $PYTORCH_DIR/scripts/build_android.sh -DANDROID_CCACHE=$(which ccache) + +cp -R $PYTORCH_DIR/build_android/install/lib $OUT_DIR/ +cp -R $PYTORCH_DIR/build_android/install/include $OUT_DIR/ + +echo "$abi build output lib,include copied to $OUT_DIR" + +LIB_LINK_PATH=$LIB_DIR/$abi +INCLUDE_LINK_PATH=$INCLUDE_DIR/$abi + +rm -f $LIB_LINK_PATH +rm -f $INCLUDE_LINK_PATH + +ln -s $OUT_DIR/lib $LIB_LINK_PATH +ln -s $OUT_DIR/include $INCLUDE_LINK_PATH + +done + +# To set proxy for gradle add following lines to ./gradle/gradle.properties: +# systemProp.http.proxyHost=... +# systemProp.http.proxyPort=8080 +# systemProp.https.proxyHost=... +# systemProp.https.proxyPort=8080 + +if [ "$CUSTOM_ABIS_LIST" = true ]; then + NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug +else + NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug +fi + +find $PYTORCH_ANDROID_DIR -type f -name *apk + +find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r " + +popd + diff --git a/android/gradle.properties b/android/gradle.properties index ff63986f2d6..a2ccde9ed70 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -22,3 +22,8 @@ ANDROID_MAVEN_GRADLE_PLUGIN_VERSION=2.1 # Gradle internals org.gradle.internal.repository.max.retries=1 org.gradle.jvmargs=-XX:MaxMetaspaceSize=1024m + +android.useAndroidX=true +android.enableJetifier=true + +nativeLibsDoNotStrip=false diff --git a/android/pytorch_android/build.gradle b/android/pytorch_android/build.gradle index e055190bb49..5e402d291b0 100644 --- a/android/pytorch_android/build.gradle +++ b/android/pytorch_android/build.gradle @@ -42,6 +42,9 @@ android { } else { pickFirst '**/libfbjni.so' } + if (nativeLibsDoNotStrip) { + doNotStrip "**/*.so" + } } useLibrary 'android.test.runner' diff --git a/android/settings.gradle b/android/settings.gradle index 99e442b2ae2..ece0c5fbdf6 100644 --- a/android/settings.gradle +++ b/android/settings.gradle @@ -1,6 +1,8 @@ -include ':app', ':pytorch_android', ':fbjni', ':pytorch_android_torchvision', ':pytorch_host' +include ':app', ':pytorch_android', ':fbjni', ':pytorch_android_torchvision', ':pytorch_host', ':test_app' project(':fbjni').projectDir = file('libs/fbjni_local') project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision') project(':pytorch_host').projectDir = file('pytorch_android/host') +project(':test_app').projectDir = file('test_app/app') + diff --git a/android/test_app/.gitignore b/android/test_app/.gitignore new file mode 100644 index 00000000000..b90a745d805 --- /dev/null +++ b/android/test_app/.gitignore @@ -0,0 +1,9 @@ +local.properties +**/*.iml +.gradle +gradlew* +gradle/wrapper +.idea/* +.DS_Store +build +.externalNativeBuild diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle new file mode 100644 index 00000000000..35b01fe4c42 --- /dev/null +++ b/android/test_app/app/build.gradle @@ -0,0 +1,61 @@ +apply plugin: 'com.android.application' + +repositories { + jcenter() +} + +def props = new Properties() +file("../gradle.properties").withInputStream { props.load(it) } + +def buildConfigProps = { k -> "\"${props.get(k)}\"" } + +android { + compileSdkVersion 28 + buildToolsVersion "29.0.2" + defaultConfig { + applicationId "org.pytorch.testapp" + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + ndk { + abiFilters ABI_FILTERS.split(",") + } + buildConfigField ("String", "MODULE_ASSET_NAME", buildConfigProps('MODULE_ASSET_NAME')) + buildConfigField ("String", "LOGCAT_TAG", "@string/app_name") + addManifestPlaceholders([APP_NAME: "@string/app_name"]) + } + buildTypes { + debug { + minifyEnabled false + } + } + flavorDimensions "model" + productFlavors { + mobNet2Quant { + dimension "model" + applicationIdSuffix ".mobNet2Quant" + buildConfigField ("String", "MODULE_ASSET_NAME", buildConfigProps('MODULE_ASSET_NAME_MOBNET2_QUANT')) + addManifestPlaceholders([APP_NAME: "PyMobNet2Quant"]) + buildConfigField ("String", "LOGCAT_TAG", "\"pytorch-mobnet2q\"") + } + resnet18 { + dimension "model" + applicationIdSuffix ".resneti18" + buildConfigField ("String", "MODULE_ASSET_NAME", buildConfigProps('MODULE_ASSET_NAME_RESNET18')) + addManifestPlaceholders([APP_NAME: "PyResNet18"]) + buildConfigField ("String", "LOGCAT_TAG", "\"pytorch-resnet18\"") + } + } + packagingOptions { + pickFirst '**/libfbjni.so' + doNotStrip '**.so' + } +} + +dependencies { + implementation 'androidx.appcompat:appcompat:1.1.0' + + implementation project(':pytorch_android') + implementation project(':pytorch_android_torchvision') +} diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..f34d7d94573 --- /dev/null +++ b/android/test_app/app/src/main/AndroidManifest.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + diff --git a/android/test_app/app/src/main/assets/.gitignore b/android/test_app/app/src/main/assets/.gitignore new file mode 100644 index 00000000000..94548af5beb --- /dev/null +++ b/android/test_app/app/src/main/assets/.gitignore @@ -0,0 +1,3 @@ +* +*/ +!.gitignore diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java new file mode 100644 index 00000000000..33ce02790f0 --- /dev/null +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java @@ -0,0 +1,155 @@ +package org.pytorch.testapp; + +import android.content.Context; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.SystemClock; +import android.util.Log; +import android.widget.TextView; + +import org.pytorch.IValue; +import org.pytorch.Module; +import org.pytorch.Tensor; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.FloatBuffer; + +import androidx.annotation.Nullable; +import androidx.annotation.UiThread; +import androidx.annotation.WorkerThread; +import androidx.appcompat.app.AppCompatActivity; + +public class MainActivity extends AppCompatActivity { + + private static final String TAG = BuildConfig.LOGCAT_TAG; + private static final int TEXT_TRIM_SIZE = 4096; + + private TextView mTextView; + + protected HandlerThread mBackgroundThread; + protected Handler mBackgroundHandler; + private Module mModule; + private FloatBuffer mInputTensorBuffer; + private Tensor mInputTensor; + private StringBuilder mTextViewStringBuilder = new StringBuilder(); + + private final Runnable mModuleForwardRunnable = new Runnable() { + @Override + public void run() { + final Result result = doModuleForward(); + runOnUiThread(new Runnable() { + @Override + public void run() { + handleResult(result); + if (mBackgroundHandler != null) { + mBackgroundHandler.post(mModuleForwardRunnable); + } + } + }); + } + }; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + mTextView = findViewById(R.id.text); + startBackgroundThread(); + mBackgroundHandler.post(mModuleForwardRunnable); + } + + protected void startBackgroundThread() { + mBackgroundThread = new HandlerThread(TAG + "_bg"); + mBackgroundThread.start(); + mBackgroundHandler = new Handler(mBackgroundThread.getLooper()); + } + + @Override + protected void onDestroy() { + stopBackgroundThread(); + super.onDestroy(); + } + + protected void stopBackgroundThread() { + mBackgroundThread.quitSafely(); + try { + mBackgroundThread.join(); + mBackgroundThread = null; + mBackgroundHandler = null; + } catch (InterruptedException e) { + Log.e(TAG, "Error stopping background thread", e); + } + } + + @WorkerThread + @Nullable + protected Result doModuleForward() { + if (mModule == null) { + final String moduleFileAbsoluteFilePath = new File( + assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath(); + mModule = Module.load(moduleFileAbsoluteFilePath); + mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224); + mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, 224, 224}); + } + + final long startTime = SystemClock.elapsedRealtime(); + final long moduleForwardStartTime = SystemClock.elapsedRealtime(); + final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor(); + final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime; + final float[] scores = outputTensor.getDataAsFloatArray(); + final long analysisDuration = SystemClock.elapsedRealtime() - startTime; + + return new Result(scores, moduleForwardDuration, analysisDuration); + } + + public static String assetFilePath(Context context, String assetName) { + File file = new File(context.getFilesDir(), assetName); + if (file.exists() && file.length() > 0) { + return file.getAbsolutePath(); + } + + try (InputStream is = context.getAssets().open(assetName)) { + try (OutputStream os = new FileOutputStream(file)) { + byte[] buffer = new byte[4 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + os.flush(); + } + return file.getAbsolutePath(); + } catch (IOException e) { + Log.e(TAG, "Error process asset " + assetName + " to file path"); + } + return null; + } + + static class Result { + + private final float[] scores; + private final long totalDuration; + private final long moduleForwardDuration; + + public Result(float[] scores, long moduleForwardDuration, long totalDuration) { + this.scores = scores; + this.moduleForwardDuration = moduleForwardDuration; + this.totalDuration = totalDuration; + } + } + + @UiThread + protected void handleResult(Result result) { + String message = String.format("forwardDuration:%d", result.moduleForwardDuration); + Log.i(TAG, message); + mTextViewStringBuilder.insert(0, '\n').insert(0, message); + if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) { + mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length()); + } + mTextView.setText(mTextViewStringBuilder.toString()); + } +} diff --git a/android/test_app/app/src/main/res/layout/activity_main.xml b/android/test_app/app/src/main/res/layout/activity_main.xml new file mode 100644 index 00000000000..c0939ebc0eb --- /dev/null +++ b/android/test_app/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,17 @@ + + + + + + \ No newline at end of file diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png new file mode 100644 index 00000000000..64ba76f75e9 Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher.png differ diff --git a/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png new file mode 100644 index 00000000000..dae5e082342 Binary files /dev/null and b/android/test_app/app/src/main/res/mipmap-mdpi/ic_launcher_round.png differ diff --git a/android/test_app/app/src/main/res/values/colors.xml b/android/test_app/app/src/main/res/values/colors.xml new file mode 100644 index 00000000000..69b22338c65 --- /dev/null +++ b/android/test_app/app/src/main/res/values/colors.xml @@ -0,0 +1,6 @@ + + + #008577 + #00574B + #D81B60 + diff --git a/android/test_app/app/src/main/res/values/strings.xml b/android/test_app/app/src/main/res/values/strings.xml new file mode 100644 index 00000000000..b8c9ca13d68 --- /dev/null +++ b/android/test_app/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + PyTest + diff --git a/android/test_app/app/src/main/res/values/styles.xml b/android/test_app/app/src/main/res/values/styles.xml new file mode 100644 index 00000000000..5885930df6d --- /dev/null +++ b/android/test_app/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/android/test_app/gradle.properties b/android/test_app/gradle.properties new file mode 100644 index 00000000000..4037ede282f --- /dev/null +++ b/android/test_app/gradle.properties @@ -0,0 +1,7 @@ +android.useAndroidX=true +android.enableJetifier=true + +MODULE_ASSET_NAME_DEFAULT=mobilenet2q.pt +MODULE_ASSET_NAME_MOBNET2_QUANT=mobilenet2q.pt +MODULE_ASSET_NAME_RESNET18=resnet18.pt + diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py new file mode 100644 index 00000000000..8f9c8749199 --- /dev/null +++ b/android/test_app/make_assets.py @@ -0,0 +1,16 @@ +import torch +import torchvision + +print(torch.version.__version__) + +resnet18 = torchvision.models.resnet18(pretrained=True) +resnet18.eval() +resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save("app/src/main/assets/resnet18.pt") + +resnet50 = torchvision.models.resnet50(pretrained=True) +resnet50.eval() +torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save("app/src/main/assets/resnet50.pt") + +mobilenet2q = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True) +mobilenet2q.eval() +torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save("app/src/main/assets/mobilenet2q.pt") diff --git a/scripts/build_android.sh b/scripts/build_android.sh index 6cc93fae1d4..5cea0c220fb 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -106,6 +106,10 @@ CMAKE_ARGS+=("-DANDROID_ABI=$ANDROID_ABI") CMAKE_ARGS+=("-DANDROID_NATIVE_API_LEVEL=$ANDROID_NATIVE_API_LEVEL") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=rtti exceptions") +if [ "${ANDROID_DEBUG_SYMBOLS:-}" == '1' ]; then + CMAKE_ARGS+=("-DANDROID_DEBUG_SYMBOLS=1") +fi + # Use-specified CMake arguments go last to allow overridding defaults CMAKE_ARGS+=($@)