diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index bf162fa9c..d37f39881 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -1428,7 +1428,8 @@ class Channel(AbstractChannel): self.logger.info("settle_htlc") assert self.can_send_ctx_updates(), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}" htlc = self.hm.get_htlc_by_id(REMOTE, htlc_id) - assert htlc.payment_hash == sha256(preimage) + if htlc.payment_hash != sha256(preimage): + raise Exception("incorrect preimage for HTLC") assert htlc_id not in self.hm.log[REMOTE]['settles'] self.hm.send_settle(htlc_id) @@ -1450,7 +1451,8 @@ class Channel(AbstractChannel): """ self.logger.info("receive_htlc_settle") htlc = self.hm.get_htlc_by_id(LOCAL, htlc_id) - assert htlc.payment_hash == sha256(preimage) + if htlc.payment_hash != sha256(preimage): + raise RemoteMisbehaving("received incorrect preimage for HTLC") assert htlc_id not in self.hm.log[LOCAL]['settles'] with self.db_lock: self.hm.recv_settle(htlc_id) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 9e6442ec1..2c150d5b4 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -2048,15 +2048,21 @@ class LNWallet(LNWorker): return key_list def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True): - assert sha256(preimage) == payment_hash + if sha256(preimage) != payment_hash: + raise Exception("tried to save incorrect preimage for payment_hash") self.preimages[payment_hash.hex()] = preimage.hex() if write_to_disk: self.wallet.save_db() def get_preimage(self, payment_hash: bytes) -> Optional[bytes]: assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}" - r = self.preimages.get(payment_hash.hex()) - return bytes.fromhex(r) if r else None + preimage_hex = self.preimages.get(payment_hash.hex()) + if preimage_hex is None: + return None + preimage_bytes = bytes.fromhex(preimage_hex) + if sha256(preimage_bytes) != payment_hash: + raise Exception("found incorrect preimage for payment_hash") + return preimage_bytes def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]: """returns None if payment_hash is a payment we are forwarding"""