Browse Source

lnhtlc: fix adding htlc between sending commitment_signed and receiving revoke_and_ack

master
SomberNight 7 years ago committed by ThomasV
parent
commit
69bffac86a
  1. 3
      electrum/lnchannel.py
  2. 39
      electrum/lnhtlc.py
  3. 3
      electrum/lnpeer.py
  4. 2
      electrum/tests/test_lnchannel.py
  5. 21
      electrum/tests/test_lnhtlc.py

3
electrum/lnchannel.py

@ -153,7 +153,7 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked.deserialize(True) self.remote_commitment_to_be_revoked.deserialize(True)
log = state.get('log') log = state.get('log')
self.hm = HTLCManager(self.config[LOCAL].ctn if log else 0, self.config[REMOTE].ctn if log else 0, log) self.hm = HTLCManager(self.config[LOCAL].ctn, self.config[REMOTE].ctn, log)
self.name = name self.name = name
Logger.__init__(self) Logger.__init__(self)
@ -194,6 +194,7 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked = self.pending_commitment(REMOTE) self.remote_commitment_to_be_revoked = self.pending_commitment(REMOTE)
self.config[REMOTE] = self.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_pcp, next_per_commitment_point=None) self.config[REMOTE] = self.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_pcp, next_per_commitment_point=None)
self.config[LOCAL] = self.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) self.config[LOCAL] = self.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig)
self.hm.channel_open_finished()
self.set_state('OPENING') self.set_state('OPENING')
def set_force_closed(self): def set_force_closed(self):

39
electrum/lnhtlc.py

@ -10,6 +10,11 @@ class HTLCManager:
def __init__(self, local_ctn=0, remote_ctn=0, log=None): def __init__(self, local_ctn=0, remote_ctn=0, log=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub # self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn} self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
# self.ctn_latest[sub] is the ctn for the latest (newest that has a valid sig) ctx of sub
self.ctn_latest = {LOCAL:local_ctn, REMOTE: remote_ctn} # FIXME does this need to be persisted?
# after sending commitment_signed but before receiving revoke_and_ack,
# self.ctn_latest[REMOTE] == self.ctn[REMOTE] + 1
# otherwise they are equal
self.expect_sig = {SENT: False, RECEIVED: False} self.expect_sig = {SENT: False, RECEIVED: False}
if log is None: if log is None:
initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}} initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
@ -35,33 +40,39 @@ class HTLCManager:
x[sub]['adds'] = d x[sub]['adds'] = d
return x return x
def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0}
self.ctn_latest = {LOCAL: 0, REMOTE: 0}
def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
htlc_id = htlc.htlc_id htlc_id = htlc.htlc_id
adds = self.log[LOCAL]['adds'] adds = self.log[LOCAL]['adds']
assert type(adds) is not str assert type(adds) is not str
adds[htlc_id] = htlc adds[htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn[REMOTE]+1} self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest[REMOTE]+1}
self.expect_sig[SENT] = True self.expect_sig[SENT] = True
return htlc return htlc
def recv_htlc(self, htlc: UpdateAddHtlc) -> None: def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
htlc_id = htlc.htlc_id htlc_id = htlc.htlc_id
self.log[REMOTE]['adds'][htlc_id] = htlc self.log[REMOTE]['adds'][htlc_id] = htlc
l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn[LOCAL]+1, REMOTE: None} l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest[LOCAL]+1, REMOTE: None}
self.expect_sig[RECEIVED] = True self.expect_sig[RECEIVED] = True
def send_ctx(self) -> None: def send_ctx(self) -> None:
next_ctn = self.ctn[REMOTE] + 1 assert self.ctn_latest[REMOTE] == self.ctn[REMOTE], (self.ctn_latest[REMOTE], self.ctn[REMOTE])
self.ctn_latest[REMOTE] = self.ctn[REMOTE] + 1
for locked_in in self.log[REMOTE]['locked_in'].values(): for locked_in in self.log[REMOTE]['locked_in'].values():
if locked_in[REMOTE] is None: if locked_in[REMOTE] is None:
locked_in[REMOTE] = next_ctn locked_in[REMOTE] = self.ctn_latest[REMOTE]
self.expect_sig[SENT] = False self.expect_sig[SENT] = False
def recv_ctx(self) -> None: def recv_ctx(self) -> None:
next_ctn = self.ctn[LOCAL] + 1 assert self.ctn_latest[LOCAL] == self.ctn[LOCAL], (self.ctn_latest[LOCAL], self.ctn[LOCAL])
self.ctn_latest[LOCAL] = self.ctn[LOCAL] + 1
for locked_in in self.log[LOCAL]['locked_in'].values(): for locked_in in self.log[LOCAL]['locked_in'].values():
if locked_in[LOCAL] is None: if locked_in[LOCAL] is None:
locked_in[LOCAL] = next_ctn locked_in[LOCAL] = self.ctn_latest[LOCAL]
self.expect_sig[RECEIVED] = False self.expect_sig[RECEIVED] = False
def send_rev(self) -> None: def send_rev(self) -> None:
@ -69,18 +80,18 @@ class HTLCManager:
for log_action in ('settles', 'fails'): for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[LOCAL][log_action].items(): for htlc_id, ctns in self.log[LOCAL][log_action].items():
if ctns[REMOTE] is None: if ctns[REMOTE] is None:
ctns[REMOTE] = self.ctn[REMOTE] + 1 ctns[REMOTE] = self.ctn_latest[REMOTE] + 1
def recv_rev(self) -> None: def recv_rev(self) -> None:
self.ctn[REMOTE] += 1 self.ctn[REMOTE] += 1
for htlc_id, ctns in self.log[LOCAL]['locked_in'].items(): for htlc_id, ctns in self.log[LOCAL]['locked_in'].items():
if ctns[LOCAL] is None: if ctns[LOCAL] is None:
assert ctns[REMOTE] == self.ctn[REMOTE] #assert ctns[REMOTE] == self.ctn[REMOTE] # FIXME I don't think this assert is correct
ctns[LOCAL] = self.ctn[LOCAL] + 1 ctns[LOCAL] = self.ctn_latest[LOCAL] + 1
for log_action in ('settles', 'fails'): for log_action in ('settles', 'fails'):
for htlc_id, ctns in self.log[REMOTE][log_action].items(): for htlc_id, ctns in self.log[REMOTE][log_action].items():
if ctns[LOCAL] is None: if ctns[LOCAL] is None:
ctns[LOCAL] = self.ctn[LOCAL] + 1 ctns[LOCAL] = self.ctn_latest[LOCAL] + 1
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]: ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -136,10 +147,10 @@ class HTLCManager:
return self.htlcs(subject, ctn) return self.htlcs(subject, ctn)
def send_settle(self, htlc_id: int) -> None: def send_settle(self, htlc_id: int) -> None:
self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn[REMOTE] + 1} self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest[REMOTE] + 1}
def recv_settle(self, htlc_id: int) -> None: def recv_settle(self, htlc_id: int) -> None:
self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn[LOCAL] + 1, REMOTE: None} self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest[LOCAL] + 1, REMOTE: None}
def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction, def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]: ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -181,7 +192,7 @@ class HTLCManager:
if ctns[LOCAL] == ctn] if ctns[LOCAL] == ctn]
def send_fail(self, htlc_id: int) -> None: def send_fail(self, htlc_id: int) -> None:
self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn[REMOTE] + 1} self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest[REMOTE] + 1}
def recv_fail(self, htlc_id: int) -> None: def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn[LOCAL] + 1, REMOTE: None} self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest[LOCAL] + 1, REMOTE: None}

3
electrum/lnpeer.py

@ -1071,7 +1071,7 @@ class Peer(Logger):
remote_ctn = chan.get_current_ctn(REMOTE) remote_ctn = chan.get_current_ctn(REMOTE)
chan.onion_keys[htlc.htlc_id] = secret_key chan.onion_keys[htlc.htlc_id] = secret_key
self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route
self.logger.info(f"starting payment. route: {route}") self.logger.info(f"starting payment. route: {route}. htlc: {htlc}")
self.send_message("update_add_htlc", self.send_message("update_add_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,
id=htlc.htlc_id, id=htlc.htlc_id,
@ -1271,6 +1271,7 @@ class Peer(Logger):
self._remote_changed_events[chan.channel_id].set() self._remote_changed_events[chan.channel_id].set()
self._remote_changed_events[chan.channel_id].clear() self._remote_changed_events[chan.channel_id].clear()
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
self.maybe_send_commitment(chan)
def on_update_fee(self, payload): def on_update_fee(self, payload):
channel_id = payload["channel_id"] channel_id = payload["channel_id"]

2
electrum/tests/test_lnchannel.py

@ -164,6 +164,8 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0) alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0)
bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0) bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0)
alice.hm.channel_open_finished()
bob.hm.channel_open_finished()
return alice, bob return alice, bob

21
electrum/tests/test_lnhtlc.py

@ -1,6 +1,6 @@
from pprint import pprint from pprint import pprint
import unittest import unittest
from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction
from electrum.lnhtlc import HTLCManager from electrum.lnhtlc import HTLCManager
from typing import NamedTuple from typing import NamedTuple
@ -135,3 +135,22 @@ class TestHTLCManager(unittest.TestCase):
htlc_lifecycle(htlc_success=True) htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False) htlc_lifecycle(htlc_success=False)
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
A = HTLCManager()
B = HTLCManager()
A.send_ctx()
B.recv_ctx()
B.send_rev()
ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
self.assertEqual([], A.current_htlcs(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE))
self.assertEqual([], A.pending_htlcs(LOCAL))
self.assertEqual([], A.pending_htlcs(REMOTE))
A.recv_rev()
self.assertEqual([], A.current_htlcs(LOCAL))
self.assertEqual([], A.current_htlcs(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.pending_htlcs(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.pending_htlcs(REMOTE))

Loading…
Cancel
Save