diff --git a/app/src/main/java/to/bitkit/utils/NetworkValidationHelper.kt b/app/src/main/java/to/bitkit/utils/NetworkValidationHelper.kt new file mode 100644 index 000000000..5eb631705 --- /dev/null +++ b/app/src/main/java/to/bitkit/utils/NetworkValidationHelper.kt @@ -0,0 +1,50 @@ +package to.bitkit.utils + +import org.lightningdevkit.ldknode.Network + +/** + * Helper for validating Bitcoin network compatibility of addresses and invoices + */ +object NetworkValidationHelper { + + /** + * Infer the Bitcoin network from an on-chain address prefix + * @param address The Bitcoin address to check + * @return The detected network, or null if the address format is unrecognized + */ + fun getAddressNetwork(address: String): Network? { + val lowercased = address.lowercase() + + // Bech32/Bech32m addresses (order matters: check bcrt1 before bc1) + return when { + lowercased.startsWith("bcrt1") -> Network.REGTEST + lowercased.startsWith("bc1") -> Network.BITCOIN + lowercased.startsWith("tb1") -> Network.TESTNET + else -> { + // Legacy addresses - check first character + when (address.firstOrNull()) { + '1', '3' -> Network.BITCOIN + 'm', 'n', '2' -> Network.TESTNET // testnet and regtest share these + else -> null + } + } + } + } + + /** + * Check if an address/invoice network mismatches the current app network + * @param addressNetwork The network detected from the address/invoice + * @param currentNetwork The app's current network (typically Env.network) + * @return true if there's a mismatch (address won't work on current network) + */ + fun isNetworkMismatch(addressNetwork: Network?, currentNetwork: Network): Boolean { + if (addressNetwork == null) return false + + // Special case: regtest uses testnet prefixes (m, n, 2, tb1) + if (currentNetwork == Network.REGTEST && addressNetwork == Network.TESTNET) { + return false + } + + return addressNetwork != currentNetwork + } +} diff --git a/app/src/main/java/to/bitkit/viewmodels/AppViewModel.kt b/app/src/main/java/to/bitkit/viewmodels/AppViewModel.kt index 94c9a8cb2..222b1e219 100644 --- a/app/src/main/java/to/bitkit/viewmodels/AppViewModel.kt +++ b/app/src/main/java/to/bitkit/viewmodels/AppViewModel.kt @@ -28,6 +28,7 @@ import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.qualifiers.ApplicationContext import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -107,6 +108,7 @@ import to.bitkit.ui.shared.toast.ToastQueueManager import to.bitkit.ui.sheets.SendRoute import to.bitkit.ui.theme.TRANSITION_SCREEN_MS import to.bitkit.utils.Logger +import to.bitkit.utils.NetworkValidationHelper import to.bitkit.utils.jsonLogOf import to.bitkit.utils.timedsheets.TimedSheetManager import to.bitkit.utils.timedsheets.sheets.AppUpdateTimedSheet @@ -196,6 +198,7 @@ class AppViewModel @Inject constructor( registerSheet(highBalanceSheet) } private var isCompletingMigration = false + private var addressValidationJob: Job? = null fun setShowForgotPin(value: Boolean) { _showForgotPinSheet.value = value @@ -664,6 +667,7 @@ class AppViewModel @Inject constructor( } private fun resetAddressInput() { + addressValidationJob?.cancel() _sendUiState.update { state -> state.copy( addressInput = "", @@ -674,15 +678,128 @@ class AppViewModel @Inject constructor( private fun onAddressChange(value: String) { val valueWithoutSpaces = value.removeSpaces() - viewModelScope.launch { - val result = runCatching { decode(valueWithoutSpaces) } - _sendUiState.update { - it.copy( - addressInput = valueWithoutSpaces, - isAddressInputValid = result.isSuccess, + + // Update text immediately, reset validity until validation completes + _sendUiState.update { + it.copy( + addressInput = valueWithoutSpaces, + isAddressInputValid = false, + ) + } + + // Cancel pending validation + addressValidationJob?.cancel() + + // Skip validation for empty input + if (valueWithoutSpaces.isEmpty()) return + + // Start debounced validation + addressValidationJob = viewModelScope.launch { + delay(ADDRESS_VALIDATION_DEBOUNCE_MS) + validateAddressWithFeedback(valueWithoutSpaces) + } + } + + private suspend fun validateAddressWithFeedback(input: String) = withContext(bgDispatcher) { + val scanResult = runCatching { decode(input) } + + if (scanResult.isFailure) { + showAddressValidationError( + titleRes = R.string.other__scan_err_decoding, + descriptionRes = R.string.other__scan__error__generic, + testTag = "InvalidAddressToast", + ) + return@withContext + } + + when (val decoded = scanResult.getOrNull()) { + is Scanner.Lightning -> validateLightningInvoice(decoded.invoice) + is Scanner.OnChain -> validateOnChainAddress(decoded.invoice) + else -> _sendUiState.update { it.copy(isAddressInputValid = true) } + } + } + + private suspend fun validateLightningInvoice(invoice: LightningInvoice) { + if (invoice.isExpired) { + showAddressValidationError( + titleRes = R.string.other__scan_err_decoding, + descriptionRes = R.string.other__scan__error__expired, + testTag = "ExpiredLightningToast", + ) + return + } + + if (invoice.amountSatoshis > 0uL) { + val maxSendLightning = walletRepo.balanceState.value.maxSendLightningSats + if (maxSendLightning == 0uL || !lightningRepo.canSend(invoice.amountSatoshis)) { + val shortfall = invoice.amountSatoshis - maxSendLightning + showAddressValidationError( + titleRes = R.string.other__pay_insufficient_spending, + descriptionRes = R.string.other__pay_insufficient_spending_amount_description, + descriptionArgs = mapOf("amount" to shortfall.toString()), + testTag = "InsufficientSpendingToast", ) + return } } + + _sendUiState.update { it.copy(isAddressInputValid = true) } + } + + private fun validateOnChainAddress(invoice: OnChainInvoice) { + // Check network mismatch + val addressNetwork = NetworkValidationHelper.getAddressNetwork(invoice.address) + if (NetworkValidationHelper.isNetworkMismatch(addressNetwork, Env.network)) { + showAddressValidationError( + titleRes = R.string.other__scan_err_decoding, + descriptionRes = R.string.other__scan__error__generic, + testTag = "InvalidAddressToast", + ) + return + } + + val maxSendOnchain = walletRepo.balanceState.value.maxSendOnchainSats + + if (maxSendOnchain == 0uL) { + showAddressValidationError( + titleRes = R.string.other__pay_insufficient_savings, + descriptionRes = R.string.other__pay_insufficient_savings_description, + testTag = "InsufficientSavingsToast", + ) + return + } + + if (invoice.amountSatoshis > 0uL && invoice.amountSatoshis > maxSendOnchain) { + val shortfall = invoice.amountSatoshis - maxSendOnchain + showAddressValidationError( + titleRes = R.string.other__pay_insufficient_savings, + descriptionRes = R.string.other__pay_insufficient_savings_amount_description, + descriptionArgs = mapOf("amount" to shortfall.toString()), + testTag = "InsufficientSavingsToast", + ) + return + } + + _sendUiState.update { it.copy(isAddressInputValid = true) } + } + + private fun showAddressValidationError( + @StringRes titleRes: Int, + @StringRes descriptionRes: Int, + descriptionArgs: Map = emptyMap(), + testTag: String? = null, + ) { + _sendUiState.update { it.copy(isAddressInputValid = false) } + var description = context.getString(descriptionRes) + descriptionArgs.forEach { (key, value) -> + description = description.replace("{$key}", value) + } + toast( + type = Toast.ToastType.ERROR, + title = context.getString(titleRes), + description = description, + testTag = testTag, + ) } private fun onAddressContinue(data: String) { @@ -883,20 +1000,26 @@ class AppViewModel @Inject constructor( } } - @Suppress("LongMethod") + @Suppress("LongMethod", "CyclomaticComplexMethod") private suspend fun onScanOnchain(invoice: OnChainInvoice, scanResult: String) { + // Check network mismatch + val addressNetwork = NetworkValidationHelper.getAddressNetwork(invoice.address) + if (NetworkValidationHelper.isNetworkMismatch(addressNetwork, Env.network)) { + toast( + type = Toast.ToastType.ERROR, + title = context.getString(R.string.other__scan_err_decoding), + description = context.getString(R.string.other__scan__error__generic), + testTag = "InvalidAddressToast", + ) + return + } + val lnInvoice: LightningInvoice? = invoice.params?.get("lightning")?.let { bolt11 -> runCatching { decode(bolt11) }.getOrNull() ?.let { it as? Scanner.Lightning } ?.invoice ?.takeIf { invoice -> if (invoice.isExpired) { - toast( - type = Toast.ToastType.ERROR, - title = context.getString(R.string.other__scan_err_decoding), - description = context.getString(R.string.other__scan__error__expired), - ) - Logger.debug( "Lightning invoice expired in unified URI, defaulting to onchain-only", context = TAG @@ -943,6 +1066,31 @@ class AppViewModel @Inject constructor( return } + // Check on-chain balance before proceeding to amount screen + val maxSendOnchain = walletRepo.balanceState.value.maxSendOnchainSats + if (maxSendOnchain == 0uL) { + toast( + type = Toast.ToastType.ERROR, + title = context.getString(R.string.other__pay_insufficient_savings), + description = context.getString(R.string.other__pay_insufficient_savings_description), + testTag = "InsufficientSavingsToast", + ) + return + } + + // Check if on-chain invoice amount exceeds available balance + if (invoice.amountSatoshis > 0uL && invoice.amountSatoshis > maxSendOnchain) { + val shortfall = invoice.amountSatoshis - maxSendOnchain + toast( + type = Toast.ToastType.ERROR, + title = context.getString(R.string.other__pay_insufficient_savings), + description = context.getString(R.string.other__pay_insufficient_savings_amount_description) + .replace("{amount}", shortfall.toString()), + testTag = "InsufficientSavingsToast", + ) + return + } + Logger.info( when (invoice.amountSatoshis > 0u) { true -> "Found amount in invoice, proceeding to edit amount" @@ -964,6 +1112,7 @@ class AppViewModel @Inject constructor( type = Toast.ToastType.ERROR, title = context.getString(R.string.other__scan_err_decoding), description = context.getString(R.string.other__scan__error__expired), + testTag = "ExpiredLightningToast", ) return } @@ -972,10 +1121,14 @@ class AppViewModel @Inject constructor( if (quickPayHandled) return if (!lightningRepo.canSend(invoice.amountSatoshis)) { + val maxSendLightning = walletRepo.balanceState.value.maxSendLightningSats + val shortfall = invoice.amountSatoshis - maxSendLightning toast( type = Toast.ToastType.ERROR, - title = context.getString(R.string.wallet__error_insufficient_funds_title), - description = context.getString(R.string.wallet__error_insufficient_funds_msg) + title = context.getString(R.string.other__pay_insufficient_spending), + description = context.getString(R.string.other__pay_insufficient_spending_amount_description) + .replace("{amount}", shortfall.toString()), + testTag = "InsufficientSpendingToast", ) return } @@ -1700,6 +1853,7 @@ class AppViewModel @Inject constructor( } suspend fun resetSendState() { + addressValidationJob?.cancel() val speed = settingsStore.data.first().defaultTransactionSpeed val rates = let { // Refresh blocktank info to get latest fee rates @@ -2029,6 +2183,7 @@ class AppViewModel @Inject constructor( private const val REMOTE_RESTORE_NODE_RESTART_DELAY_MS = 500L private const val AUTH_CHECK_INITIAL_DELAY_MS = 1000L private const val AUTH_CHECK_SPLASH_DELAY_MS = 500L + private const val ADDRESS_VALIDATION_DEBOUNCE_MS = 1000L } } diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index 54f5d0b3c..4eec55d59 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -465,6 +465,7 @@ Insufficient Savings ₿ {amount} more needed to pay this Bitcoin invoice. More ₿ needed to pay this Bitcoin invoice. + More ₿ needed to pay this Lightning invoice. Insufficient Spending Balance ₿ {amount} more needed to pay this Lightning invoice. Open Phone Settings diff --git a/app/src/test/java/to/bitkit/utils/NetworkValidationHelperTest.kt b/app/src/test/java/to/bitkit/utils/NetworkValidationHelperTest.kt new file mode 100644 index 000000000..5fb0cc2d2 --- /dev/null +++ b/app/src/test/java/to/bitkit/utils/NetworkValidationHelperTest.kt @@ -0,0 +1,151 @@ +package to.bitkit.utils + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Test +import org.lightningdevkit.ldknode.Network + +class NetworkValidationHelperTest { + + // MARK: - getAddressNetwork Tests + + // Mainnet addresses + @Test + fun `getAddressNetwork - mainnet bech32`() { + val address = "bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4" + assertEquals(Network.BITCOIN, NetworkValidationHelper.getAddressNetwork(address)) + } + + @Test + fun `getAddressNetwork - mainnet bech32 uppercase`() { + val address = "BC1QW508D6QEJXTDG4Y5R3ZARVARY0C5XW7KV8F3T4" + assertEquals(Network.BITCOIN, NetworkValidationHelper.getAddressNetwork(address)) + } + + @Test + fun `getAddressNetwork - mainnet P2PKH`() { + val address = "1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2" + assertEquals(Network.BITCOIN, NetworkValidationHelper.getAddressNetwork(address)) + } + + @Test + fun `getAddressNetwork - mainnet P2SH`() { + val address = "3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy" + assertEquals(Network.BITCOIN, NetworkValidationHelper.getAddressNetwork(address)) + } + + // Testnet addresses + @Test + fun `getAddressNetwork - testnet bech32`() { + val address = "tb1qw508d6qejxtdg4y5r3zarvary0c5xw7kxpjzsx" + assertEquals(Network.TESTNET, NetworkValidationHelper.getAddressNetwork(address)) + } + + @Test + fun `getAddressNetwork - testnet P2PKH m prefix`() { + val address = "mipcBbFg9gMiCh81Kj8tqqdgoZub1ZJRfn" + assertEquals(Network.TESTNET, NetworkValidationHelper.getAddressNetwork(address)) + } + + @Test + fun `getAddressNetwork - testnet P2PKH n prefix`() { + val address = "n3ZddxzLvAY9o7184TB4c6FJasAybsw4HZ" + assertEquals(Network.TESTNET, NetworkValidationHelper.getAddressNetwork(address)) + } + + @Test + fun `getAddressNetwork - testnet P2SH`() { + val address = "2MzQwSSnBHWHqSAqtTVQ6v47XtaisrJa1Vc" + assertEquals(Network.TESTNET, NetworkValidationHelper.getAddressNetwork(address)) + } + + // Regtest addresses + @Test + fun `getAddressNetwork - regtest bech32`() { + val address = "bcrt1q6rhpng9evdsfnn833a4f4vej0asu6dk5srld6x" + assertEquals(Network.REGTEST, NetworkValidationHelper.getAddressNetwork(address)) + } + + // Edge cases + @Test + fun `getAddressNetwork - empty string`() { + assertNull(NetworkValidationHelper.getAddressNetwork("")) + } + + @Test + fun `getAddressNetwork - invalid address`() { + assertNull(NetworkValidationHelper.getAddressNetwork("invalid")) + } + + @Test + fun `getAddressNetwork - random text`() { + assertNull(NetworkValidationHelper.getAddressNetwork("test123")) + } + + // MARK: - isNetworkMismatch Tests + + @Test + fun `isNetworkMismatch - same network`() { + assertFalse(NetworkValidationHelper.isNetworkMismatch(Network.BITCOIN, Network.BITCOIN)) + assertFalse(NetworkValidationHelper.isNetworkMismatch(Network.TESTNET, Network.TESTNET)) + assertFalse(NetworkValidationHelper.isNetworkMismatch(Network.REGTEST, Network.REGTEST)) + } + + @Test + fun `isNetworkMismatch - different network`() { + assertTrue(NetworkValidationHelper.isNetworkMismatch(Network.BITCOIN, Network.TESTNET)) + assertTrue(NetworkValidationHelper.isNetworkMismatch(Network.BITCOIN, Network.REGTEST)) + assertTrue(NetworkValidationHelper.isNetworkMismatch(Network.TESTNET, Network.BITCOIN)) + } + + @Test + fun `isNetworkMismatch - regtest accepts testnet prefixes`() { + // Regtest should accept testnet prefixes (m, n, 2, tb1) + assertFalse(NetworkValidationHelper.isNetworkMismatch(Network.TESTNET, Network.REGTEST)) + } + + @Test + fun `isNetworkMismatch - testnet rejects regtest addresses`() { + // Testnet should NOT accept regtest-specific addresses (bcrt1) + assertTrue(NetworkValidationHelper.isNetworkMismatch(Network.REGTEST, Network.TESTNET)) + } + + @Test + fun `isNetworkMismatch - null address network`() { + // When address network is nil (unrecognized format), no mismatch + assertFalse(NetworkValidationHelper.isNetworkMismatch(null, Network.BITCOIN)) + assertFalse(NetworkValidationHelper.isNetworkMismatch(null, Network.REGTEST)) + } + + // MARK: - Integration Tests + + @Test + fun `mainnet address on regtest should mismatch`() { + val address = "bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4" + val addressNetwork = NetworkValidationHelper.getAddressNetwork(address) + assertTrue(NetworkValidationHelper.isNetworkMismatch(addressNetwork, Network.REGTEST)) + } + + @Test + fun `testnet address on regtest should not mismatch`() { + val address = "tb1qw508d6qejxtdg4y5r3zarvary0c5xw7kxpjzsx" + val addressNetwork = NetworkValidationHelper.getAddressNetwork(address) + assertFalse(NetworkValidationHelper.isNetworkMismatch(addressNetwork, Network.REGTEST)) + } + + @Test + fun `regtest address on mainnet should mismatch`() { + val address = "bcrt1q6rhpng9evdsfnn833a4f4vej0asu6dk5srld6x" + val addressNetwork = NetworkValidationHelper.getAddressNetwork(address) + assertTrue(NetworkValidationHelper.isNetworkMismatch(addressNetwork, Network.BITCOIN)) + } + + @Test + fun `legacy testnet address on regtest should not mismatch`() { + val address = "mipcBbFg9gMiCh81Kj8tqqdgoZub1ZJRfn" // m-prefix testnet + val addressNetwork = NetworkValidationHelper.getAddressNetwork(address) + assertFalse(NetworkValidationHelper.isNetworkMismatch(addressNetwork, Network.REGTEST)) + } +}