diff --git a/electrum/crypto.py b/electrum/crypto.py index 70192eaf0..645d0561f 100644 --- a/electrum/crypto.py +++ b/electrum/crypto.py @@ -229,21 +229,29 @@ def pw_decode_bytes(data: str, password: Union[bytes, str], *, version:int) -> b return pw_decode_raw(data_bytes, password, version=version) -def pw_encode_b64_with_version(data: bytes, password: Union[bytes, str]) -> str: +def pw_encode_with_version_and_mac(data: bytes, password: Union[bytes, str]) -> str: """plaintext bytes -> base64 ciphertext""" + # https://crypto.stackexchange.com/questions/202/should-we-mac-then-encrypt-or-encrypt-then-mac + # Encrypt-and-MAC. The MAC will be used to detect invalid passwords version = PW_HASH_VERSION_LATEST + mac = sha256(data)[0:4] ciphertext = pw_encode_raw(data, password, version=version) - ciphertext_b64 = base64.b64encode(bytes([version]) + ciphertext) + ciphertext_b64 = base64.b64encode(bytes([version]) + ciphertext + mac) return ciphertext_b64.decode('utf8') -def pw_decode_b64_with_version(data: str, password: Union[bytes, str]) -> bytes: +def pw_decode_with_version_and_mac(data: str, password: Union[bytes, str]) -> bytes: """base64 ciphertext -> plaintext bytes""" data_bytes = bytes(base64.b64decode(data)) version = int(data_bytes[0]) + encrypted = data_bytes[1:-4] + mac = data_bytes[-4:] if version not in KNOWN_PW_HASH_VERSIONS: raise UnexpectedPasswordHashVersion(version) - return pw_decode_raw(data_bytes[1:], password, version=version) + decrypted = pw_decode_raw(encrypted, password, version=version) + if sha256(decrypted)[0:4] != mac: + raise InvalidPassword() + return decrypted def pw_encode(data: str, password: Union[bytes, str, None], *, version: int) -> str: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 99f01f21a..4bc956674 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -66,8 +66,7 @@ from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_sane_t from .address_synchronizer import TX_HEIGHT_LOCAL from . import lnsweep from .lnwatcher import LNWalletWatcher -from .crypto import pw_encode_bytes, pw_decode_bytes, PW_HASH_VERSION_LATEST -from .crypto import pw_encode_b64_with_version, pw_decode_b64_with_version +from .crypto import pw_encode_with_version_and_mac, pw_decode_with_version_and_mac from .lnutil import ChannelBackupStorage from .lnchannel import ChannelBackup from .channel_db import UpdateStatus @@ -1397,8 +1396,8 @@ class LNWallet(LNWorker): xpub = self.wallet.get_fingerprint() backup_bytes = self.create_channel_backup(channel_id).to_bytes() assert backup_bytes == ChannelBackupStorage.from_bytes(backup_bytes).to_bytes(), "roundtrip failed" - encrypted = pw_encode_b64_with_version(backup_bytes, xpub) - assert backup_bytes == pw_decode_b64_with_version(encrypted, xpub), "encrypt failed" + encrypted = pw_encode_with_version_and_mac(backup_bytes, xpub) + assert backup_bytes == pw_decode_with_version_and_mac(encrypted, xpub), "encrypt failed" return 'channel_backup:' + encrypted @@ -1454,7 +1453,7 @@ class LNBackups(Logger): assert data.startswith('channel_backup:') encrypted = data[15:] xpub = self.wallet.get_fingerprint() - decrypted = pw_decode_b64_with_version(encrypted, xpub) + decrypted = pw_decode_with_version_and_mac(encrypted, xpub) cb_storage = ChannelBackupStorage.from_bytes(decrypted) channel_id = cb_storage.channel_id().hex() d = self.db.get_dict("channel_backups")