8 changed files with 697 additions and 130 deletions
@ -0,0 +1,100 @@
|
||||
import datetime |
||||
import os |
||||
|
||||
import jwt |
||||
|
||||
from jmbase.support import bintohex |
||||
|
||||
|
||||
class InvalidScopeError(Exception): |
||||
pass |
||||
|
||||
|
||||
class ExpiredSignatureError(jwt.exceptions.ExpiredSignatureError): |
||||
pass |
||||
|
||||
|
||||
def get_random_key(size: int = 16) -> str: |
||||
"""Create a random key has an hexadecimal string.""" |
||||
return bintohex(os.urandom(size)) |
||||
|
||||
|
||||
class JMTokenAuthority: |
||||
"""Manage authorization tokens.""" |
||||
|
||||
SESSION_VALIDITY = { |
||||
"access": datetime.timedelta(minutes=30), |
||||
"refresh": datetime.timedelta(hours=4), |
||||
} |
||||
SIGNATURE_ALGORITHM = "HS256" |
||||
|
||||
def __init__(self, *wallet_names: str): |
||||
self.signature_key = { |
||||
"access": get_random_key(), |
||||
"refresh": get_random_key(), |
||||
} |
||||
self._scope = {"walletrpc"} |
||||
for wallet_name in wallet_names: |
||||
self.add_to_scope(wallet_name) |
||||
|
||||
def verify(self, token: str, *, token_type: str = "access"): |
||||
"""Verify JWT token. |
||||
|
||||
Token must have a valid signature and its scope must contain both scopes in |
||||
arguments and wallet_name property. |
||||
""" |
||||
try: |
||||
claims = jwt.decode( |
||||
token, |
||||
self.signature_key[token_type], |
||||
algorithms=self.SIGNATURE_ALGORITHM, |
||||
leeway=10, |
||||
) |
||||
except jwt.exceptions.ExpiredSignatureError: |
||||
raise ExpiredSignatureError |
||||
|
||||
token_claims = set(claims.get("scope", []).split()) |
||||
if not self._scope <= token_claims: |
||||
raise InvalidScopeError |
||||
|
||||
def add_to_scope(self, *args: str): |
||||
for arg in args: |
||||
self._scope.add(arg) |
||||
|
||||
def discard_from_scope(self, *args: str): |
||||
for arg in args: |
||||
self._scope.discard(arg) |
||||
|
||||
@property |
||||
def scope(self): |
||||
return " ".join(self._scope) |
||||
|
||||
def _issue(self, token_type: str) -> str: |
||||
return jwt.encode( |
||||
{ |
||||
"exp": datetime.datetime.utcnow() + self.SESSION_VALIDITY[token_type], |
||||
"scope": self.scope, |
||||
}, |
||||
self.signature_key[token_type], |
||||
algorithm=self.SIGNATURE_ALGORITHM, |
||||
) |
||||
|
||||
def issue(self) -> dict: |
||||
"""Issue a new access and refresh token. |
||||
Previously issued refresh token is invalidated. |
||||
""" |
||||
self.signature_key["refresh"] = get_random_key() |
||||
return { |
||||
"token": self._issue("access"), |
||||
"token_type": "bearer", |
||||
"expires_in": int(self.SESSION_VALIDITY["access"].total_seconds()), |
||||
"scope": self.scope, |
||||
"refresh_token": self._issue("refresh"), |
||||
} |
||||
|
||||
def reset(self): |
||||
"""Invalidate all previously issued tokens by creating new signature keys.""" |
||||
self.signature_key = { |
||||
"access": get_random_key(), |
||||
"refresh": get_random_key(), |
||||
} |
||||
@ -0,0 +1,97 @@
|
||||
"""test auth module.""" |
||||
|
||||
import copy |
||||
import datetime |
||||
|
||||
import jwt |
||||
import pytest |
||||
|
||||
from jmclient.auth import ExpiredSignatureError, InvalidScopeError, JMTokenAuthority |
||||
|
||||
|
||||
class TestJMTokenAuthority: |
||||
wallet_name = "dummywallet" |
||||
token_auth = JMTokenAuthority(wallet_name) |
||||
|
||||
access_sig = copy.copy(token_auth.signature_key["access"]) |
||||
refresh_sig = copy.copy(token_auth.signature_key["refresh"]) |
||||
|
||||
validity = datetime.timedelta(hours=1) |
||||
scope = f"walletrpc {wallet_name}" |
||||
|
||||
@pytest.mark.parametrize( |
||||
"sig, token_type", [(access_sig, "access"), (refresh_sig, "refresh")] |
||||
) |
||||
def test_verify_valid(self, sig, token_type): |
||||
token = jwt.encode( |
||||
{"exp": datetime.datetime.utcnow() + self.validity, "scope": self.scope}, |
||||
sig, |
||||
algorithm=self.token_auth.SIGNATURE_ALGORITHM, |
||||
) |
||||
|
||||
try: |
||||
self.token_auth.verify(token, token_type=token_type) |
||||
except Exception as e: |
||||
print(e) |
||||
pytest.fail("Token verification failed, token is valid.") |
||||
|
||||
def test_verify_expired(self): |
||||
token = jwt.encode( |
||||
{"exp": datetime.datetime.utcnow() - self.validity, "scope": self.scope}, |
||||
self.access_sig, |
||||
algorithm=self.token_auth.SIGNATURE_ALGORITHM, |
||||
) |
||||
|
||||
with pytest.raises(ExpiredSignatureError): |
||||
self.token_auth.verify(token) |
||||
|
||||
def test_verify_non_scoped(self): |
||||
token = jwt.encode( |
||||
{"exp": datetime.datetime.utcnow() + self.validity, "scope": "wrong"}, |
||||
self.access_sig, |
||||
algorithm=self.token_auth.SIGNATURE_ALGORITHM, |
||||
) |
||||
|
||||
with pytest.raises(InvalidScopeError): |
||||
self.token_auth.verify(token) |
||||
|
||||
def test_issue(self): |
||||
def scope_equals(scope): |
||||
return set(scope.split(" ")) == set(self.scope.split(" ")) |
||||
|
||||
token_response = self.token_auth.issue() |
||||
|
||||
assert token_response.pop("expires_in") == int( |
||||
self.token_auth.SESSION_VALIDITY["access"].total_seconds() |
||||
) |
||||
assert token_response.pop("token_type") == "bearer" |
||||
assert scope_equals(token_response.pop("scope")) |
||||
|
||||
try: |
||||
for k, v in token_response.items(): |
||||
claims = jwt.decode( |
||||
v, |
||||
self.token_auth.signature_key["refresh"] |
||||
if k == "refresh_token" |
||||
else self.token_auth.signature_key["access"], |
||||
algorithms=self.token_auth.SIGNATURE_ALGORITHM, |
||||
) |
||||
assert scope_equals(claims.get("scope")) |
||||
assert self.token_auth.signature_key["refresh"] != self.refresh_sig |
||||
except jwt.exceptions.InvalidTokenError: |
||||
pytest.fail("An invalid token was issued.") |
||||
|
||||
def test_scope_operation(self): |
||||
assert "walletrpc" in self.token_auth._scope |
||||
assert self.wallet_name in self.token_auth._scope |
||||
|
||||
scope = copy.copy(self.token_auth._scope) |
||||
s = "new_wallet" |
||||
|
||||
self.token_auth.add_to_scope(s) |
||||
assert scope < self.token_auth._scope |
||||
assert s in self.token_auth._scope |
||||
|
||||
self.token_auth.discard_from_scope(s, "walletrpc") |
||||
assert scope > self.token_auth._scope |
||||
assert s not in self.token_auth._scope |
||||
Loading…
Reference in new issue