【Posted】: 2020-05-11 16:37:18
【Question】:
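I'm trying to use my model in Android through Firebase ML Kit.
I've tried specifying the inputs in different ways, without success. I need a way to make predictions in Android with the TensorFlow model downloaded from Firebase.
At the moment I can only feed a single value into the model's input on Android. How do I specify two inputs in Android, one for the user ID and one for the movie ID?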
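A minimal sketch, assuming the converted model exposes two separate FLOAT32 inputs of shape [1, 1], with index 0 as the user ID and index 1 as the movie ID. The class `TwoInputExample` and the method `buildInputs` are hypothetical names for illustration. The idea is that each Keras `Input` becomes its own tensor with its own index: the options builder declares one `setInputFormat` call per index, and `FirebaseModelInputs.Builder.add` is called once per tensor, in index order. The TensorFlow Lite converter does not necessarily keep the Keras `[u, m]` order, so the actual indices and shapes are worth checking in the converted `.tflite` file first.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of Firebase ML Kit): packs one user ID and one
 * movie ID into two separate [1, 1] float tensors, one per model input.
 */
class TwoInputExample {

    /** Returns the inputs in index order: 0 = user ID, 1 = movie ID (assumed order). */
    static Object[] buildInputs(int userId, int movieId) {
        float[][] user = new float[1][1];   // shape [1, 1] for input index 0
        float[][] movie = new float[1][1];  // shape [1, 1] for input index 1
        user[0][0] = userId;
        movie[0][0] = movieId;
        return new Object[]{user, movie};
    }

    public static void main(String[] args) {
        Object[] inputs = buildInputs(42, 7);
        // In the app, the formats and inputs would be declared in the same index order:
        //   new FirebaseModelInputOutputOptions.Builder()
        //       .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
        //       .setInputFormat(1, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
        //       .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
        //       .build();
        //   FirebaseModelInputs.Builder builder = new FirebaseModelInputs.Builder();
        //   for (Object in : inputs) builder.add(in);
        System.out.println(Arrays.deepToString(inputs)); // [[[42.0]], [[7.0]]]
    }
}
```

Because the IDs are stored as `float`, integer IDs above 2^24 (16,777,216) can no longer be represented exactly.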
```java
private void setupModel() {
    FirebaseCustomRemoteModel remoteModel =
            new FirebaseCustomRemoteModel.Builder("Recommender-Model").build();
    FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
            .requireWifi()
            .build();
    FirebaseModelManager.getInstance().download(remoteModel, conditions)
            .addOnCompleteListener(new OnCompleteListener<Void>() {
                @Override
                public void onComplete(@NonNull Task<Void> task) {
                    if (task.isSuccessful()) {
                        Toast.makeText(getApplicationContext(), "Downloaded", Toast.LENGTH_SHORT).show();
                        // Run inference only once the remote model is actually on the device.
                        runModel(remoteModel);
                    } else {
                        Toast.makeText(getApplicationContext(), "Download failure!", Toast.LENGTH_SHORT).show();
                    }
                }
            });
}

private void runModel(FirebaseCustomRemoteModel remoteModel) {
    try {
        // Currently a single FLOAT32 input of shape [1, 1] and a FLOAT32 output of shape [1, 1].
        FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
                .build();
        float[][] input = new float[1][1];
        input[0][0] = 1f;
        FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
                .add(input)
                .build();
        FirebaseModelInterpreterOptions interpreterOptions =
                new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
        FirebaseModelInterpreter.getInstance(interpreterOptions).run(inputs, inputOutputOptions)
                .addOnSuccessListener(new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        float[][] predictedRating = result.getOutput(0);
                        Toast.makeText(getApplicationContext(),
                                "Predicted rating: " + predictedRating[0][0], Toast.LENGTH_SHORT).show();
                    }
                })
                .addOnFailureListener(new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        Toast.makeText(getApplicationContext(), "Failure", Toast.LENGTH_SHORT).show();
                    }
                });
    } catch (FirebaseMLException e) {
        // Building the options/inputs or creating the interpreter failed.
        e.printStackTrace();
    }
}
```
In TensorFlow, the model and the prediction call look like this:
```python
model = Model(inputs=[u, m], outputs=x)
model.predict([test_user, test_movie], batch_size=500)
```
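The Keras call scores many (user, movie) pairs at once with a batch size of 500. A rough Java counterpart, assuming the converted model accepts a variable batch dimension, is to split the test pairs into chunks and build one `[n, 1]` array per input for each chunk. The class `BatchInputs` and the method `makeBatches` are hypothetical names. Each chunk would then be described with `new int[]{n, 1}` in `setInputFormat` for both inputs and in `setOutputFormat`; if the converted model's batch dimension is fixed at 1, the same code still works with `batchSize = 1`.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch: splits parallel user/movie ID arrays into chunks of at most
 * batchSize rows, producing one [n, 1] float array per model input for each chunk.
 */
class BatchInputs {

    /** Each element is {userBatch, movieBatch}, both shaped [n, 1] with n <= batchSize. */
    static List<float[][][]> makeBatches(int[] userIds, int[] movieIds, int batchSize) {
        if (userIds.length != movieIds.length) {
            throw new IllegalArgumentException("userIds and movieIds must have the same length");
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<float[][][]> batches = new ArrayList<>();
        for (int start = 0; start < userIds.length; start += batchSize) {
            int n = Math.min(batchSize, userIds.length - start); // last chunk may be smaller
            float[][] users = new float[n][1];
            float[][] movies = new float[n][1];
            for (int i = 0; i < n; i++) {
                users[i][0] = userIds[start + i];
                movies[i][0] = movieIds[start + i];
            }
            batches.add(new float[][][]{users, movies});
        }
        return batches;
    }

    public static void main(String[] args) {
        int[] users = {1, 1, 2, 3, 5};
        int[] movies = {10, 20, 10, 30, 40};
        List<float[][][]> batches = makeBatches(users, movies, 2);
        System.out.println(batches.size() + " batches"); // 3 batches (sizes 2, 2, 1)
    }
}
```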
【Discussion】:
Tags: java android firebase tensorflow recommender-systems