diff --git a/src/async_store.rs b/src/async_store.rs index dd1e5fa..62814f8 100644 --- a/src/async_store.rs +++ b/src/async_store.rs @@ -9,7 +9,7 @@ use bdk_chain::{ use bitcoin::{Amount, BlockHash, OutPoint, ScriptBuf, Transaction, TxOut, Txid, consensus}; use sqlx::{ Row, - sqlite::{SqliteConnectOptions, SqlitePool as Pool}, + sqlite::{SqliteConnectOptions, SqliteConnection, SqlitePool as Pool}, }; use crate::Error; @@ -62,48 +62,41 @@ impl Store { impl Store { /// Write tx_graph. pub async fn write_tx_graph( - &self, + conn: &mut SqliteConnection, tx_graph: &tx_graph::ChangeSet, ) -> Result<(), Error> { - let txs = &tx_graph.txs; - let txouts = &tx_graph.txouts; - let anchors = &tx_graph.anchors; - let first_seen = &tx_graph.first_seen; - let last_seen = &tx_graph.last_seen; - let last_evicted = &tx_graph.last_evicted; - - for tx in txs { + for tx in &tx_graph.txs { let txid = tx.compute_txid(); sqlx::query( "INSERT INTO tx(txid, tx) VALUES($1, $2) ON CONFLICT DO UPDATE SET tx = $2", ) .bind(txid.to_string()) .bind(consensus::encode::serialize(tx)) - .execute(&self.pool) + .execute(&mut *conn) .await?; } - for (txid, t) in first_seen { + for (txid, t) in &tx_graph.first_seen { sqlx::query("INSERT INTO tx(txid, first_seen) VALUES($1, $2) ON CONFLICT DO UPDATE SET first_seen = $2") .bind(txid.to_string()) .bind(i64::try_from(*t)?) - .execute(&self.pool) + .execute(&mut *conn) .await?; } - for (txid, t) in last_seen { + for (txid, t) in &tx_graph.last_seen { sqlx::query("INSERT INTO tx(txid, last_seen) VALUES($1, $2) ON CONFLICT DO UPDATE SET last_seen = $2") .bind(txid.to_string()) .bind(i64::try_from(*t)?) - .execute(&self.pool) + .execute(&mut *conn) .await?; } - for (txid, t) in last_evicted { + for (txid, t) in &tx_graph.last_evicted { sqlx::query("INSERT INTO tx(txid, last_evicted) VALUES($1, $2) ON CONFLICT DO UPDATE SET last_evicted = $2") .bind(txid.to_string()) .bind(i64::try_from(*t)?) - .execute(&self.pool) + .execute(&mut *conn) .await?; } - for (op, txout) in txouts { + for (op, txout) in &tx_graph.txouts { let OutPoint { txid, vout } = op; let TxOut { value, @@ -114,10 +107,10 @@ impl Store { .bind(vout) .bind(i64::try_from(value.to_sat())?) .bind(script_pubkey.to_bytes()) - .execute(&self.pool) + .execute(&mut *conn) .await?; } - for (anchor, txid) in anchors { + for (anchor, txid) in &tx_graph.anchors { let BlockId { height, hash } = anchor.block_id; let confirmation_time = anchor.confirmation_time; sqlx::query("INSERT OR IGNORE INTO anchor(block_height, block_hash, txid, confirmation_time) VALUES($1, $2, $3, $4)") @@ -125,7 +118,7 @@ impl Store { .bind(hash.to_string()) .bind(txid.to_string()) .bind(i64::try_from(confirmation_time)?) - .execute(&self.pool) + .execute(&mut *conn) .await?; } @@ -134,7 +127,7 @@ impl Store { /// Write local_chain. pub async fn write_local_chain( - &self, + conn: &mut SqliteConnection, local_chain: &local_chain::ChangeSet, ) -> Result<(), Error> { for (&height, hash) in &local_chain.blocks { @@ -143,13 +136,13 @@ impl Store { sqlx::query("INSERT OR IGNORE INTO block(height, hash) VALUES($1, $2)") .bind(height) .bind(hash.to_string()) - .execute(&self.pool) + .execute(&mut *conn) .await?; } None => { sqlx::query("DELETE FROM block WHERE height = $1") .bind(height) - .execute(&self.pool) + .execute(&mut *conn) .await?; } } @@ -160,7 +153,7 @@ impl Store { /// Write keychain_txout. pub async fn write_keychain_txout( - &self, + conn: &mut SqliteConnection, keychain_txout: &keychain_txout::ChangeSet, ) -> Result<(), Error> { for (descriptor_id, last_revealed) in &keychain_txout.last_revealed { @@ -169,7 +162,7 @@ impl Store { ) .bind(descriptor_id.to_string()) .bind(last_revealed) - .execute(&self.pool) + .execute(&mut *conn) .await?; } for (descriptor_id, spk_cache) in &keychain_txout.spk_cache { @@ -180,7 +173,7 @@ impl Store { .bind(descriptor_id.to_string()) .bind(*derivation_index) .bind(script.to_bytes()) - .execute(&self.pool) + .execute(&mut *conn) .await?; } } @@ -189,12 +182,14 @@ impl Store { } /// Read tx_graph. - pub async fn read_tx_graph(&self) -> Result, Error> { + pub async fn read_tx_graph( + conn: &mut SqliteConnection, + ) -> Result, Error> { let mut changeset = tx_graph::ChangeSet::default(); let rows: Vec = sqlx::query_as("SELECT txid, tx, first_seen, last_seen, last_evicted FROM tx") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let txid: Txid = row.txid.parse()?; @@ -216,7 +211,7 @@ impl Store { } let rows = sqlx::query("SELECT txid, vout, value, script FROM txout") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let txid: String = row.get("txid"); @@ -236,7 +231,7 @@ impl Store { let rows = sqlx::query("SELECT block_height, block_hash, txid, confirmation_time FROM anchor") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let height: u32 = row.get("block_height"); @@ -256,11 +251,13 @@ impl Store { } /// Read local_chain. - pub async fn read_local_chain(&self) -> Result { + pub async fn read_local_chain( + conn: &mut SqliteConnection, + ) -> Result { let mut changeset = local_chain::ChangeSet::default(); let rows = sqlx::query("SELECT height, hash FROM block") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let height: u32 = row.get("height"); @@ -273,11 +270,13 @@ impl Store { } /// Read keychain_txout. - pub async fn read_keychain_txout(&self) -> Result { + pub async fn read_keychain_txout( + conn: &mut SqliteConnection, + ) -> Result { let mut changeset = keychain_txout::ChangeSet::default(); let rows = sqlx::query("SELECT descriptor_id, last_revealed FROM keychain_last_revealed") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let descriptor_id: String = row.get("descriptor_id"); @@ -289,7 +288,7 @@ impl Store { let rows = sqlx::query( "SELECT descriptor_id, derivation_index, script FROM keychain_script_pubkey", ) - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { @@ -338,18 +337,25 @@ mod test { let store = Store::new_memory().await?; store.migrate().await?; - store - .write_local_chain(&cs) - .await - .expect("failed to write `local_chain`"); + + { + let mut txn = store.pool.begin().await?; + Store::write_local_chain(&mut txn, &cs) + .await + .expect("failed to write `local_chain`"); + txn.commit().await?; + } // Trying to replace the value of existing height should be ignored. cs.blocks.insert(1, Some(Hash::hash(b"1a"))); - store - .write_local_chain(&cs) - .await - .expect("failed to write `local_chain`"); + { + let mut txn = store.pool.begin().await?; + Store::write_local_chain(&mut txn, &cs) + .await + .expect("failed to write `local_chain`"); + txn.commit().await?; + } let rows = sqlx::query("SELECT height, hash FROM block WHERE height = 1") .fetch_all(&store.pool) @@ -365,16 +371,24 @@ mod test { // Delete row 1 and insert hash "1a" again. let mut cs = local_chain::ChangeSet::default(); cs.blocks.insert(1, None); - store - .write_local_chain(&cs) - .await - .expect("failed to write `local_chain`"); + + { + let mut txn = store.pool.begin().await?; + Store::write_local_chain(&mut txn, &cs) + .await + .expect("failed to write `local_chain`"); + txn.commit().await?; + } cs.blocks.insert(1, Some(Hash::hash(b"1a"))); - store - .write_local_chain(&cs) - .await - .expect("failed to write `local_chain`"); + + { + let mut txn = store.pool.begin().await?; + Store::write_local_chain(&mut txn, &cs) + .await + .expect("failed to write `local_chain`"); + txn.commit().await?; + } let rows = sqlx::query("SELECT height, hash FROM block WHERE height = 1") .fetch_all(&store.pool) diff --git a/src/wallet.rs b/src/wallet.rs index ab45d95..72c47f7 100644 --- a/src/wallet.rs +++ b/src/wallet.rs @@ -8,7 +8,7 @@ use bdk_wallet::{AsyncWalletPersister, ChangeSet, KeychainKind, locked_outpoints use bitcoin::Network; use bitcoin::OutPoint; use miniscript::descriptor::{Descriptor, DescriptorPublicKey}; -use sqlx::Row; +use sqlx::{Row, sqlite::SqliteConnection}; use crate::Error; use crate::Store; @@ -16,8 +16,10 @@ use crate::Store; impl Store { /// Write changeset. pub async fn write_changeset(&self, changeset: &ChangeSet) -> Result<(), Error> { + let mut txn = self.pool.begin().await?; + if let Some(network) = changeset.network { - self.write_network(network).await?; + Self::write_network(&mut txn, network).await?; } let mut descriptors = BTreeMap::new(); @@ -27,22 +29,22 @@ impl Store { if let Some(ref change_descriptor) = changeset.change_descriptor { descriptors.insert(KeychainKind::Internal, change_descriptor.clone()); } - self.write_keychain_descriptors(descriptors).await?; + Self::write_keychain_descriptors(&mut txn, &descriptors).await?; - self.write_local_chain(&changeset.local_chain).await?; - self.write_tx_graph(&changeset.tx_graph).await?; - self.write_keychain_txout(&changeset.indexer).await?; - self.write_locked_outpoints(&changeset.locked_outpoints) - .await?; + Self::write_local_chain(&mut txn, &changeset.local_chain).await?; + Self::write_tx_graph(&mut txn, &changeset.tx_graph).await?; + Self::write_keychain_txout(&mut txn, &changeset.indexer).await?; + Self::write_locked_outpoints(&mut txn, &changeset.locked_outpoints).await?; + txn.commit().await?; Ok(()) } /// Write network. - pub async fn write_network(&self, network: Network) -> Result<(), Error> { + pub async fn write_network(conn: &mut SqliteConnection, network: Network) -> Result<(), Error> { sqlx::query("INSERT OR IGNORE INTO network(network) VALUES($1)") .bind(network.to_string()) - .execute(&self.pool) + .execute(&mut *conn) .await?; Ok(()) @@ -50,8 +52,8 @@ impl Store { /// Write keychain descriptors. pub async fn write_keychain_descriptors( - &self, - descriptors: BTreeMap>, + conn: &mut SqliteConnection, + descriptors: &BTreeMap>, ) -> Result<(), Error> { for (keychain, descriptor) in descriptors { let keychain = match keychain { @@ -61,7 +63,7 @@ impl Store { sqlx::query("INSERT OR IGNORE INTO keychain(keychain, descriptor) VALUES($1, $2)") .bind(keychain) .bind(descriptor.to_string()) - .execute(&self.pool) + .execute(&mut *conn) .await?; } @@ -70,17 +72,20 @@ impl Store { /// Read changeset. pub async fn read_changeset(&self) -> Result { - let network = self.read_network().await?; + let mut txn = self.pool.begin().await?; + + let network = Self::read_network(&mut txn).await?; - let descriptors = self.read_keychain_descriptors().await?; + let descriptors = Self::read_keychain_descriptors(&mut txn).await?; let descriptor = descriptors.get(&KeychainKind::External).cloned(); let change_descriptor = descriptors.get(&KeychainKind::Internal).cloned(); - let tx_graph = self.read_tx_graph().await?; - let local_chain = self.read_local_chain().await?; - let indexer = self.read_keychain_txout().await?; - let locked_outpoints = self.read_locked_outpoints().await?; + let tx_graph = Self::read_tx_graph(&mut txn).await?; + let local_chain = Self::read_local_chain(&mut txn).await?; + let indexer = Self::read_keychain_txout(&mut txn).await?; + let locked_outpoints = Self::read_locked_outpoints(&mut txn).await?; + txn.commit().await?; Ok(ChangeSet { network, descriptor, @@ -93,9 +98,9 @@ impl Store { } /// Read network. - pub async fn read_network(&self) -> Result, Error> { + pub async fn read_network(conn: &mut SqliteConnection) -> Result, Error> { let row = sqlx::query("SELECT network FROM network") - .fetch_optional(&self.pool) + .fetch_optional(&mut *conn) .await?; row.map(|row| { @@ -107,12 +112,12 @@ impl Store { /// Read keychain descriptors. pub async fn read_keychain_descriptors( - &self, + conn: &mut SqliteConnection, ) -> Result>, Error> { let mut descriptors = BTreeMap::new(); let rows = sqlx::query("SELECT keychain, descriptor FROM keychain") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let keychain: u8 = row.get("keychain"); @@ -134,7 +139,7 @@ impl Store { /// Write locked outpoints. pub async fn write_locked_outpoints( - &self, + conn: &mut SqliteConnection, locked_outpoints: &locked_outpoints::ChangeSet, ) -> Result<(), Error> { for (&outpoint, &is_locked) in &locked_outpoints.outpoints { @@ -143,13 +148,13 @@ impl Store { sqlx::query("INSERT OR IGNORE INTO locked_outpoint(txid, vout) VALUES($1, $2)") .bind(txid.to_string()) .bind(vout) - .execute(&self.pool) + .execute(&mut *conn) .await?; } else { sqlx::query("DELETE FROM locked_outpoint WHERE txid = $1 AND vout = $2") .bind(txid.to_string()) .bind(vout) - .execute(&self.pool) + .execute(&mut *conn) .await?; } } @@ -158,11 +163,13 @@ impl Store { } /// Read locked outpoints. - pub async fn read_locked_outpoints(&self) -> Result { + pub async fn read_locked_outpoints( + conn: &mut SqliteConnection, + ) -> Result { let mut changeset = locked_outpoints::ChangeSet::default(); let rows = sqlx::query("SELECT txid, vout FROM locked_outpoint") - .fetch_all(&self.pool) + .fetch_all(&mut *conn) .await?; for row in rows { let txid: String = row.get("txid");