diff --git a/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt b/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt index a55b193..092aaff 100644 --- a/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt +++ b/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt @@ -29,7 +29,11 @@ enum class ModelOption( val size: String? = null, val supportsScreenshot: Boolean = true, val isOfflineModel: Boolean = false, - val offlineModelFilename: String? = null + val offlineModelFilename: String? = null, + val offlineAlternateModelFilenames: List<String> = emptyList(), + val offlineRequiredFilenames: List<String> = emptyList(), + val additionalDownloadUrls: List<String> = emptyList(), + val requiresVisionBackend: Boolean = false ) { PUTER_GLM5("GLM-5V Turbo (Puter)", "openrouter:z-ai/glm-5v-turbo", ApiProvider.PUTER, supportsScreenshot = true), MISTRAL_LARGE_3("Mistral Large 3", "mistral-large-latest", ApiProvider.MISTRAL), @@ -53,15 +57,48 @@ enum class ModelOption( "https://huggingface.co/na5h13/gemma-3n-E4B-it-litert-lm/resolve/main/gemma-3n-E4B-it-int4.litertlm?download=true", "4.92 GB", isOfflineModel = true, - offlineModelFilename = "gemma-3n-e4b-it-int4.litertlm" + offlineModelFilename = "gemma-3n-e4b-it-int4.litertlm", + offlineRequiredFilenames = listOf("gemma-3n-e4b-it-int4.litertlm") ), GEMMA_4_E4B_IT( "Gemma 4 E4B it (offline)", "gemma-4-e4b-it", ApiProvider.GOOGLE, "https://huggingface.co/litert-community/gemma-4-E4B-it-litert-lm/resolve/main/gemma-4-E4B-it.litertlm?download=true", + "3.40 GB", isOfflineModel = true, - offlineModelFilename = "gemma-4-E4B-it.litertlm" + offlineModelFilename = "gemma-4-E4B-it.litertlm", + offlineRequiredFilenames = listOf("gemma-4-E4B-it.litertlm") + ), + QWEN3_5_4B_OFFLINE( + "Qwen3.5 4B (offline)", + "qwen3.5-4b-offline", + ApiProvider.GOOGLE, + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/model_multimodal.litertlm?download=true", + "6.3 GB", + 
isOfflineModel = true, + offlineModelFilename = "model_multimodal.litertlm", + offlineAlternateModelFilenames = listOf("model_quantized.litertlm"), + offlineRequiredFilenames = listOf( + "model_multimodal.litertlm", + "sentencepiece.model", + "tokenizer.json", + "tokenizer_config.json", + "embedder_quantized.tflite", + "vision_encoder_quantized.tflite", + "vision_adapter_quantized.tflite", + "model_multimodal_llm_metadata_multimodal.pb" + ), + additionalDownloadUrls = listOf( + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/sentencepiece.model?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/embedder_quantized.tflite?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/vision_encoder_quantized.tflite?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/vision_adapter_quantized.tflite?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/model_multimodal_llm_metadata_multimodal.pb?download=true" + ), + requiresVisionBackend = true ), HUMAN_EXPERT("Human Expert", "human-expert", ApiProvider.HUMAN_EXPERT); diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt index 6bc58cd..be0c4bc 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt @@ -58,13 +58,20 @@ object ModelDownloadManager { private var downloadJob: Job? 
= null private var isPaused = false + private data class DownloadTarget( + val finalFile: File, + val tempFile: File, + val url: String, + val label: String + ) + fun isModelDownloaded(context: Context, model: ModelOption = GenerativeAiViewModelFactory.getCurrentModel()): Boolean { - val file = getModelFile(context, model) - return file != null && file.exists() && file.length() > 0 + val modelFile = getModelFile(context, model) + return modelFile != null && modelFile.exists() && modelFile.length() > 0 } fun getModelFile(context: Context, model: ModelOption = GenerativeAiViewModelFactory.getCurrentModel()): File? { - val modelFilename = model.offlineModelFilename ?: return null + val modelFilename = resolveInstalledModelFilename(context, model) ?: model.offlineModelFilename ?: return null val externalFilesDir = context.getExternalFilesDir(null) return if (externalFilesDir != null) { File(externalFilesDir, modelFilename) @@ -74,14 +81,32 @@ object ModelDownloadManager { } } - private fun getTempFile(context: Context, model: ModelOption): File? { - val modelFilename = model.offlineModelFilename ?: return null - val externalFilesDir = context.getExternalFilesDir(null) - return if (externalFilesDir != null) { - File(externalFilesDir, modelFilename + TEMP_SUFFIX) + private fun getRequiredFiles(context: Context, model: ModelOption): List<File> { + val externalFilesDir = context.getExternalFilesDir(null) ?: return emptyList() + val activeModelFilename = resolveInstalledModelFilename(context, model) + val requiredNames = if (model == ModelOption.QWEN3_5_4B_OFFLINE && activeModelFilename != null) { + // User requirement: first accept standalone model file and let runtime decide if add-ons are needed. 
+ listOf(activeModelFilename) + } else if (model.offlineRequiredFilenames.isNotEmpty()) { + model.offlineRequiredFilenames } else { - null + listOfNotNull(model.offlineModelFilename) } + return requiredNames.map { File(externalFilesDir, it) } + } + + private fun resolveInstalledModelFilename(context: Context, model: ModelOption): String? { + val externalFilesDir = context.getExternalFilesDir(null) ?: return null + val candidates = listOfNotNull(model.offlineModelFilename) + model.offlineAlternateModelFilenames + return candidates.firstOrNull { name -> + val f = File(externalFilesDir, name) + f.exists() && f.length() > 0 + } + } + + fun getMissingRequiredFiles(context: Context, model: ModelOption): List<String> { + val requiredFiles = getRequiredFiles(context, model) + return requiredFiles.filter { !it.exists() || it.length() <= 0 }.map { it.name } } private fun createNotificationChannel(context: Context) { @@ -147,7 +172,7 @@ object ModelDownloadManager { isPaused = false downloadJob = CoroutineScope(Dispatchers.IO).launch { - downloadWithResume(context, model, url) + downloadModelPackage(context, model, url) } } @@ -164,7 +189,7 @@ object ModelDownloadManager { isPaused = false downloadJob = CoroutineScope(Dispatchers.IO).launch { - downloadWithResume(context, model, url) + downloadModelPackage(context, model, url) } } @@ -174,11 +199,16 @@ object ModelDownloadManager { downloadJob?.cancel() downloadJob = null - // Delete temp file - val tempFile = getTempFile(context, model) - if (tempFile != null && tempFile.exists()) { - tempFile.delete() - Log.d(TAG, "Temp file deleted.") + // Delete temp files for full package + val externalFilesDir = context.getExternalFilesDir(null) + if (externalFilesDir != null) { + val targets = buildDownloadTargets(context, model, model.downloadUrl ?: "") + targets.forEach { target -> + if (target.tempFile.exists()) { + target.tempFile.delete() + } + } + Log.d(TAG, "Temporary package files deleted.") } _downloadState.value = DownloadState.Idle 
@@ -188,21 +218,79 @@ } } - private suspend fun downloadWithResume(context: Context, model: ModelOption, url: String) { - val tempFile = getTempFile(context, model) ?: run { + private suspend fun downloadModelPackage(context: Context, model: ModelOption, primaryUrl: String) { + val targets = buildDownloadTargets(context, model, primaryUrl) + if (targets.isEmpty()) { _downloadState.value = DownloadState.Error("Storage not available.") return } - val finalFile = getModelFile(context, model) ?: run { - _downloadState.value = DownloadState.Error("Storage not available.") - return + + for ((index, target) in targets.withIndex()) { + if (!coroutineContext.isActive) return + Log.i(TAG, "Downloading package file ${index + 1}/${targets.size}: ${target.label}") + val error = downloadSingleFileWithResume(context, target, index, targets.size) + if (error != null) { + _downloadState.value = DownloadState.Error(error) + cancelDownloadNotification(context) + return + } + } + + _downloadState.value = DownloadState.Completed + showDownloadCompleteNotification(context) + withContext(Dispatchers.Main) { + Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show() + } + } + + private fun buildDownloadTargets(context: Context, model: ModelOption, primaryUrl: String): List<DownloadTarget> { + val externalFilesDir = context.getExternalFilesDir(null) ?: return emptyList() + val primaryFilename = model.offlineModelFilename ?: return emptyList() + val urls = listOf(primaryUrl) + model.additionalDownloadUrls + val filenames = urls.mapIndexedNotNull { idx, url -> + if (idx == 0) primaryFilename else filenameFromUrl(url) + } + if (urls.size != filenames.size) { + Log.e(TAG, "Could not resolve filename for at least one download URL.") + return emptyList() + } + return urls.zip(filenames).map { (url, filename) -> + val finalFile = File(externalFilesDir, filename) + DownloadTarget( + finalFile = finalFile, + tempFile = File(externalFilesDir, 
"$filename$TEMP_SUFFIX"), + url = url, + label = filename + ) + } + } + + private fun filenameFromUrl(url: String): String? { + val clean = url.substringBefore('?') + val slash = clean.lastIndexOf('/') + return if (slash >= 0 && slash + 1 < clean.length) clean.substring(slash + 1) else null + } + + private suspend fun downloadSingleFileWithResume( + context: Context, + target: DownloadTarget, + fileIndex: Int, + fileCount: Int + ): String? { + val tempFile = target.tempFile + val finalFile = target.finalFile + val url = target.url + + if (finalFile.exists() && finalFile.length() > 0L) { + Log.d(TAG, "Skipping already downloaded file: ${target.label}") + return null } var retryCount = 0 var bytesDownloaded = if (tempFile.exists()) tempFile.length() else 0L while (retryCount <= MAX_RETRIES) { - if (!coroutineContext.isActive) return // Coroutine was cancelled + if (!coroutineContext.isActive) return null // Coroutine was cancelled var connection: HttpURLConnection? = null try { @@ -240,9 +328,7 @@ object ModelDownloadManager { } } else -> { - _downloadState.value = DownloadState.Error("Server error: $responseCode") - cancelDownloadNotification(context) - return + return "Server error for ${target.label}: $responseCode" } } @@ -264,7 +350,7 @@ object ModelDownloadManager { if (!coroutineContext.isActive) { Log.d(TAG, "Download cancelled during read.") cancelDownloadNotification(context) - return + return null } if (isPaused) { @@ -275,7 +361,7 @@ object ModelDownloadManager { ) // Keep notification showing paused state showDownloadNotification(context, bytesDownloaded.toFloat() / totalBytes, bytesDownloaded, totalBytes) - return + return null } output.write(buffer, 0, bytesRead) @@ -286,13 +372,14 @@ object ModelDownloadManager { if (now - lastProgressUpdate >= PROGRESS_UPDATE_INTERVAL_MS) { lastProgressUpdate = now val progress = if (totalBytes > 0) bytesDownloaded.toFloat() / totalBytes else 0f + val aggregateProgress = (fileIndex + progress) / fileCount.toFloat() 
_downloadState.value = DownloadState.Downloading( - progress = progress, + progress = aggregateProgress, bytesDownloaded = bytesDownloaded, totalBytes = totalBytes ) // Point 18: Update notification with progress - showDownloadNotification(context, progress, bytesDownloaded, totalBytes) + showDownloadNotification(context, aggregateProgress, bytesDownloaded, totalBytes) } } } @@ -303,30 +390,20 @@ object ModelDownloadManager { finalFile.delete() if (tempFile.renameTo(finalFile)) { Log.i(TAG, "Download complete! File: ${finalFile.absolutePath} (${finalFile.length()} bytes)") - _downloadState.value = DownloadState.Completed - showDownloadCompleteNotification(context) - withContext(Dispatchers.Main) { - Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show() - } } else { - _downloadState.value = DownloadState.Error("Failed to save model file.") - cancelDownloadNotification(context) + return "Failed to save ${target.label}." } } - return // Success, exit retry loop + return null // Success, exit retry loop } catch (e: IOException) { Log.e(TAG, "Download error (attempt ${retryCount + 1}): ${e.message}") retryCount++ if (retryCount > MAX_RETRIES) { - _downloadState.value = DownloadState.Error("Download failed after $MAX_RETRIES retries: ${e.message}") - cancelDownloadNotification(context) - withContext(Dispatchers.Main) { - Toast.makeText(context, "Download failed: ${e.message}", Toast.LENGTH_LONG).show() - } + return "Download failed for ${target.label} after $MAX_RETRIES retries: ${e.message}" } else { _downloadState.value = DownloadState.Downloading( - progress = if (bytesDownloaded > 0) 0f else 0f, + progress = fileIndex.toFloat() / fileCount.toFloat(), bytesDownloaded = bytesDownloaded, totalBytes = -1 ) @@ -337,6 +414,8 @@ object ModelDownloadManager { connection?.disconnect() } } + + return "Download failed for ${target.label}." 
} /** diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt index a059839..d94eef3 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt @@ -48,6 +48,7 @@ import androidx.localbroadcastmanager.content.LocalBroadcastManager import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.withContext +import java.io.File import java.io.IOException import java.util.concurrent.atomic.AtomicBoolean @@ -327,45 +328,92 @@ class PhotoReasoningViewModel( private fun initializeOfflineModel(context: Context): String? { try { val currentModel = com.google.ai.sample.GenerativeAiViewModelFactory.getCurrentModel() - val modelFile = ModelDownloadManager.getModelFile(context, currentModel) - if (modelFile != null && modelFile.exists()) { + val missingFiles = ModelDownloadManager.getMissingRequiredFiles(context, currentModel) + if (missingFiles.isNotEmpty()) { + return "Offline model files missing: ${missingFiles.joinToString(", ")}. Please redownload the model package." + } + val selectedModelFile = ModelDownloadManager.getModelFile(context, currentModel) + if (selectedModelFile != null && selectedModelFile.exists()) { // Load backend preference GenerativeAiViewModelFactory.loadBackendPreference(context) val backend = GenerativeAiViewModelFactory.getBackend() + val isLiteRtModel = currentModel.offlineModelFilename?.endsWith(".litertlm", ignoreCase = true) == true - if (currentModel == ModelOption.GEMMA_4_E4B_IT) { + if (isLiteRtModel) { if (!isLiteRtAbiSupported()) { - return "Gemma 4 offline is only supported on arm64-v8a or x86_64 devices." + return "Offline LiteRT models are only supported on arm64-v8a or x86_64 devices." 
+ } + + val externalFilesDir = context.getExternalFilesDir(null) + val candidateNames = linkedSetOf<String>().apply { + add(selectedModelFile.name) + currentModel.offlineModelFilename?.let { add(it) } + addAll(currentModel.offlineAlternateModelFilenames) } + val candidateFiles = candidateNames + .mapNotNull { name -> externalFilesDir?.let { File(it, name) } } + .filter { it.exists() && it.length() > 0L } + Log.i( TAG, - "Initializing Gemma 4 LiteRT engine. preferredBackend=$backend, " + + "Initializing LiteRT engine for ${currentModel.displayName}. preferredBackend=$backend, " + "abis=${Build.SUPPORTED_ABIS?.joinToString() ?: "unknown"}, " + - "modelPath=${modelFile.absolutePath}, modelSizeBytes=${modelFile.length()}" + "candidateFiles=${candidateFiles.joinToString { "${it.name}(${it.length()}B)" }}" ) + if (liteRtEngine == null) { val preferredBackend = if (backend == InferenceBackend.GPU) Backend.GPU() else Backend.CPU() - val preferredVisionBackend = if (currentModel.supportsScreenshot) Backend.GPU() else null - val audioBackend = null - val cacheDir = - if (modelFile.absolutePath.startsWith("/data/local/tmp")) { - context.getExternalFilesDir(null)?.absolutePath - } else { - null + + val attempts = if (candidateFiles.isNotEmpty()) candidateFiles else listOf(selectedModelFile) + val failureDetails = StringBuilder() + + attempts.forEachIndexed { index, modelFile -> + try { + val useVisionBackend = currentModel.requiresVisionBackend && + modelFile.name.contains("multimodal", ignoreCase = true) + val preferredVisionBackend = if (useVisionBackend) { + if (backend == InferenceBackend.GPU) Backend.GPU() else Backend.CPU() + } else { + null + } + val audioBackend = null + val cacheDir = + if (modelFile.absolutePath.startsWith("/data/local/tmp")) { + context.getExternalFilesDir(null)?.absolutePath + } else { + null + } + + Log.i( + TAG, + "LiteRT model file attempt ${index + 1}/${attempts.size}: " + + "modelFile=${modelFile.absolutePath}, size=${modelFile.length()}, 
useVision=$useVisionBackend" + ) + + liteRtEngine = createLiteRtEngineWithFallbacks( + modelPath = modelFile.absolutePath, + preferredBackend = preferredBackend, + preferredVisionBackend = preferredVisionBackend, + audioBackend = audioBackend, + cacheDir = cacheDir + ) + Log.d(TAG, "Offline model initialized with LiteRT-LM Engine using ${modelFile.name}") + return null + } catch (e: Exception) { + val msg = e.message ?: e.toString() + failureDetails.append("${modelFile.name}: $msg\n") + Log.e(TAG, "LiteRT file attempt failed for ${modelFile.name}", e) } - liteRtEngine = createLiteRtEngineWithFallbacks( - modelPath = modelFile.absolutePath, - preferredBackend = preferredBackend, - preferredVisionBackend = preferredVisionBackend, - audioBackend = audioBackend, - cacheDir = cacheDir + } + + throw IllegalStateException( + "All model-file attempts failed for ${currentModel.displayName}.\n$failureDetails" ) - Log.d(TAG, "Offline model initialized with LiteRT-LM Engine") } } else { if (llmInference == null) { val optionsBuilder = LlmInference.LlmInferenceOptions.builder() - .setModelPath(modelFile.absolutePath) + .setModelPath(selectedModelFile.absolutePath) .setMaxTokens(4096) // Set preferred backend (CPU or GPU) @@ -401,6 +449,11 @@ class PhotoReasoningViewModel( ) { return "LiteRT native runtime is not available on this device/ABI. Use an arm64-v8a or x86_64 build." } + if (msg.contains("litert_compiled_model", ignoreCase = true) || + msg.contains("litert_tensor_buffer", ignoreCase = true) + ) { + return "Offline model could not be initialized: LiteRT cannot compile this model package on this device. Check model files and try CPU backend." + } return if (msg.contains("memory", ignoreCase = true) || msg.contains("RAM", ignoreCase = true) || msg.contains("OOM", ignoreCase = true) || msg.contains("alloc", ignoreCase = true) || msg.contains("out of", ignoreCase = true)) { "Not enough RAM to load the model on GPU. Try switching to CPU." 
} else { @@ -835,7 +888,8 @@ class PhotoReasoningViewModel( // Initialize model if needed var initError: String? = null val selectedOfflineModel = GenerativeAiViewModelFactory.getCurrentModel() - if (selectedOfflineModel == ModelOption.GEMMA_4_E4B_IT) { + val useLiteRt = selectedOfflineModel.offlineModelFilename?.endsWith(".litertlm", ignoreCase = true) == true + if (useLiteRt) { if (liteRtEngine == null) { withContext(Dispatchers.Main) { replaceAiMessageText("Initializing offline model...", isPending = true) @@ -860,7 +914,7 @@ class PhotoReasoningViewModel( _isInitializingOfflineModelFlow.value = false } - if (selectedOfflineModel == ModelOption.GEMMA_4_E4B_IT && liteRtEngine == null) { + if (useLiteRt && liteRtEngine == null) { val errorMsg = initError ?: "Offline model could not be initialized." withContext(Dispatchers.Main) { _uiState.value = PhotoReasoningUiState.Error(errorMsg) @@ -875,7 +929,7 @@ class PhotoReasoningViewModel( refreshStopButtonState() } return@launch - } else if (selectedOfflineModel != ModelOption.GEMMA_4_E4B_IT && llmInference == null) { + } else if (!useLiteRt && llmInference == null) { val errorMsg = initError ?: "Offline model could not be initialized." withContext(Dispatchers.Main) { _uiState.value = PhotoReasoningUiState.Error(errorMsg) @@ -896,7 +950,7 @@ class PhotoReasoningViewModel( Log.d(TAG, "Sending streaming prompt to offline model (length: ${fullPrompt.length})") - val finalResponse = if (selectedOfflineModel == ModelOption.GEMMA_4_E4B_IT) { + val finalResponse = if (useLiteRt) { val engine = liteRtEngine if (engine == null) { withContext(Dispatchers.Main) {