diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 540102de2..778891526 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -11,7 +11,7 @@ from .crypto import sha256, hash_160 from .ecc import ECPrivkey from .bitcoin import (script_to_p2wsh, opcodes, p2wsh_nested_script, push_script, is_segwit_address, construct_witness) -from .transaction import PartialTxInput, PartialTxOutput, PartialTransaction, Transaction, TxInput +from .transaction import PartialTxInput, PartialTxOutput, PartialTransaction, Transaction, TxInput, TxOutpoint from .transaction import script_GetOp, match_script_against_template, OPPushDataGeneric, OPPushDataPubkey from .util import log_exceptions from .lnutil import REDEEM_AFTER_DOUBLE_SPENT_DELAY, ln_dummy_address @@ -92,7 +92,15 @@ class SwapData(StoredObject): funding_txid = attr.ib(type=Optional[str]) spending_txid = attr.ib(type=Optional[str]) is_redeemed = attr.ib(type=bool) - _funding_prevout = None # for RBF + + _funding_prevout = None # type: Optional[TxOutpoint] # for RBF + __payment_hash = None + + @property + def payment_hash(self) -> bytes: + if self.__payment_hash is None: + self.__payment_hash = sha256(self.preimage) + return self.__payment_hash def create_claim_tx( @@ -135,8 +143,14 @@ class SwapManager(Logger): self._max_amount = 0 self.wallet = wallet self.lnworker = lnworker + self.swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData] - self.prepayments = {} # type: Dict[bytes, bytes] # fee_preimage -> preimage + self._swaps_by_funding_outpoint = {} # type: Dict[TxOutpoint, SwapData] + self._swaps_by_lockup_address = {} # type: Dict[str, SwapData] + for payment_hash, swap in self.swaps.items(): + self._add_or_reindex_swap(swap) + + self.prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash for k, swap in self.swaps.items(): if swap.is_reverse and swap.prepay_hash is not None: self.prepayments[swap.prepay_hash] = bytes.fromhex(k) @@ -173,6 +187,7 @@ class SwapManager(Logger): continue swap.funding_txid = txin.prevout.txid.hex() swap._funding_prevout = txin.prevout + self._add_or_reindex_swap(swap) # to update _swaps_by_funding_outpoint spent_height = txin.spent_height if spent_height is not None: swap.spending_txid = txin.spent_txid @@ -319,7 +334,7 @@ class SwapManager(Logger): funding_txid = None, spending_txid = None, ) - self.swaps[payment_hash.hex()] = swap + self._add_or_reindex_swap(swap) self.add_lnwatcher_callback(swap) await self.network.broadcast_transaction(tx) return tx.txid() @@ -418,7 +433,7 @@ class SwapManager(Logger): funding_txid = None, spending_txid = None, ) - self.swaps[preimage_hash.hex()] = swap + self._add_or_reindex_swap(swap) # add callback to lnwatcher self.add_lnwatcher_callback(swap) # initiate payment. @@ -429,6 +444,13 @@ class SwapManager(Logger): success, log = await self.lnworker.pay_invoice(invoice, attempts=10) return success + def _add_or_reindex_swap(self, swap: SwapData) -> None: + if swap.payment_hash.hex() not in self.swaps: + self.swaps[swap.payment_hash.hex()] = swap + if swap._funding_prevout: + self._swaps_by_funding_outpoint[swap._funding_prevout] = swap + self._swaps_by_lockup_address[swap.lockup_address] = swap + async def get_pairs(self) -> None: assert self.network response = await self.network._send_http_on_proxy( @@ -543,16 +565,10 @@ class SwapManager(Logger): return self.get_swap_by_claim_txin(txin) def get_swap_by_claim_txin(self, txin: TxInput) -> Optional[SwapData]: - for key, swap in self.swaps.items(): - if txin.prevout == swap._funding_prevout: - return swap - return None + return self._swaps_by_funding_outpoint.get(txin.prevout) def is_lockup_address_for_a_swap(self, addr: str) -> bool: - for key, swap in self.swaps.items(): # TODO take lock? or add index to avoid looping - if addr == swap.lockup_address: - return True - return False + return bool(self._swaps_by_lockup_address.get(addr)) def add_txin_info(self, txin: PartialTxInput) -> None: """Add some info to a claim txin.