Browse Source

lnhtlc: fix adding htlc between sending commitment_signed and receiving revoke_and_ack

master
SomberNight 7 years ago committed by ThomasV
parent
commit
69bffac86a
  1. 3
      electrum/lnchannel.py
  2. 39
      electrum/lnhtlc.py
  3. 3
      electrum/lnpeer.py
  4. 2
      electrum/tests/test_lnchannel.py
  5. 21
      electrum/tests/test_lnhtlc.py

3
electrum/lnchannel.py

@ -153,7 +153,7 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked.deserialize(True)
log = state.get('log')
self.hm = HTLCManager(self.config[LOCAL].ctn if log else 0, self.config[REMOTE].ctn if log else 0, log)
self.hm = HTLCManager(self.config[LOCAL].ctn, self.config[REMOTE].ctn, log)
self.name = name
Logger.__init__(self)
@ -194,6 +194,7 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked = self.pending_commitment(REMOTE)
self.config[REMOTE] = self.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_pcp, next_per_commitment_point=None)
self.config[LOCAL] = self.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig)
self.hm.channel_open_finished()
self.set_state('OPENING')
def set_force_closed(self):

39
electrum/lnhtlc.py

@ -10,6 +10,11 @@ class HTLCManager:
def __init__(self, local_ctn=0, remote_ctn=0, log=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
# self.ctn_latest[sub] is the ctn for the latest (newest that has a valid sig) ctx of sub
self.ctn_latest = {LOCAL:local_ctn, REMOTE: remote_ctn} # FIXME does this need to be persisted?
# after sending commitment_signed but before receiving revoke_and_ack,
# self.ctn_latest[REMOTE] == self.ctn[REMOTE] + 1
# otherwise they are equal
self.expect_sig = {SENT: False, RECEIVED: False}
if log is None:
initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
@ -35,33 +40,39 @@ class HTLCManager:
x[sub]['adds'] = d
return x
def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0}
self.ctn_latest = {LOCAL: 0, REMOTE: 0}
def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
htlc_id = htlc.htlc_id
adds = self.log[LOCAL]['adds']
assert type(adds) is not str
adds[htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn[REMOTE]+1}
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest[REMOTE]+1}
self.expect_sig[SENT] = True
return htlc
def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
htlc_id = htlc.htlc_id
self.log[REMOTE]['adds'][htlc_id] = htlc
l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn[LOCAL]+1, REMOTE: None}
l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest[LOCAL]+1, REMOTE: None}
self.expect_sig[RECEIVED] = True
def send_ctx(self) -> None:
next_ctn = self.ctn[REMOTE] + 1
assert self.ctn_latest[REMOTE] == self.ctn[REMOTE], (self.ctn_latest[REMOTE], self.ctn[REMOTE])
self.ctn_latest[REMOTE] = self.ctn[REMOTE] + 1
for locked_in in self.log[REMOTE]['locked_in'].values():
if locked_in[REMOTE] is None:
locked_in[REMOTE] = next_ctn
locked_in[REMOTE] = self.ctn_latest[REMOTE]
self.expect_sig[SENT] = False
def recv_ctx(self) -> None:
next_ctn = self.ctn[LOCAL] + 1
assert self.ctn_latest[LOCAL] == self.ctn[LOCAL], (self.ctn_latest[LOCAL], self.ctn[LOCAL])
self.ctn_latest[LOCAL] = self.ctn[LOCAL] + 1
for locked_in in self.log[LOCAL]['locked_in'].values():
if locked_in[LOCAL] is None:
locked_in[LOCAL] = next_ctn
locked_in[LOCAL] = self.ctn_latest[LOCAL]
self.expect_sig[RECEIVED] = False
def send_rev(self) -> None:
@ -69,18 +80,18 @@ class HTLCManager:
for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[LOCAL][log_action].items():
if ctns[REMOTE] is None:
ctns[REMOTE] = self.ctn[REMOTE] + 1
ctns[REMOTE] = self.ctn_latest[REMOTE] + 1
def recv_rev(self) -> None:
self.ctn[REMOTE] += 1
for htlc_id, ctns in self.log[LOCAL]['locked_in'].items():
if ctns[LOCAL] is None:
assert ctns[REMOTE] == self.ctn[REMOTE]
ctns[LOCAL] = self.ctn[LOCAL] + 1
#assert ctns[REMOTE] == self.ctn[REMOTE] # FIXME I don't think this assert is correct
ctns[LOCAL] = self.ctn_latest[LOCAL] + 1
for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[REMOTE][log_action].items():
if ctns[LOCAL] is None:
ctns[LOCAL] = self.ctn[LOCAL] + 1
ctns[LOCAL] = self.ctn_latest[LOCAL] + 1
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -136,10 +147,10 @@ class HTLCManager:
return self.htlcs(subject, ctn)
def send_settle(self, htlc_id: int) -> None:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn[REMOTE] + 1}
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest[REMOTE] + 1}
def recv_settle(self, htlc_id: int) -> None:
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn[LOCAL] + 1, REMOTE: None}
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest[LOCAL] + 1, REMOTE: None}
def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -181,7 +192,7 @@ class HTLCManager:
if ctns[LOCAL] == ctn]
def send_fail(self, htlc_id: int) -> None:
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn[REMOTE] + 1}
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest[REMOTE] + 1}
def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn[LOCAL] + 1, REMOTE: None}
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest[LOCAL] + 1, REMOTE: None}

3
electrum/lnpeer.py

@ -1071,7 +1071,7 @@ class Peer(Logger):
remote_ctn = chan.get_current_ctn(REMOTE)
chan.onion_keys[htlc.htlc_id] = secret_key
self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route
self.logger.info(f"starting payment. route: {route}")
self.logger.info(f"starting payment. route: {route}. htlc: {htlc}")
self.send_message("update_add_htlc",
channel_id=chan.channel_id,
id=htlc.htlc_id,
@ -1271,6 +1271,7 @@ class Peer(Logger):
self._remote_changed_events[chan.channel_id].set()
self._remote_changed_events[chan.channel_id].clear()
self.lnworker.save_channel(chan)
self.maybe_send_commitment(chan)
def on_update_fee(self, payload):
channel_id = payload["channel_id"]

2
electrum/tests/test_lnchannel.py

@ -164,6 +164,8 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0)
bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0)
alice.hm.channel_open_finished()
bob.hm.channel_open_finished()
return alice, bob

21
electrum/tests/test_lnhtlc.py

@ -1,6 +1,6 @@
from pprint import pprint
import unittest
from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner
from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction
from electrum.lnhtlc import HTLCManager
from typing import NamedTuple
@ -135,3 +135,22 @@ class TestHTLCManager(unittest.TestCase):
htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False)
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
A = HTLCManager()
B = HTLCManager()
A.send_ctx()
B.recv_ctx()
B.send_rev()
ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
self.assertEqual([], A.current_htlcs(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE))
self.assertEqual([], A.pending_htlcs(LOCAL))
self.assertEqual([], A.pending_htlcs(REMOTE))
A.recv_rev()
self.assertEqual([], A.current_htlcs(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.pending_htlcs(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.pending_htlcs(REMOTE))

Loading…
Cancel
Save