Skip to content

Commit 90590f0

Browse files
authored
Sem-Ver: feature Reduce the time taken to generate a jwt by caching loaded private key instances
* Sem-Ver: feature Reduce the time taken to generate a jwt by passing through a load private key instance to jwt.encode. Signed-off-by: David Black <dblack@atlassian.com> * Sem-Ver: feature Introduce a private keys cache so as to support PrivateKeyRetriever that can/may provide different load method results. Signed-off-by: David Black <dblack@atlassian.com>
1 parent d8ba53a commit 90590f0

2 files changed

Lines changed: 34 additions & 2 deletions

File tree

atlassian_jwt_auth/signer.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import random
44

55
import jwt
6+
from cryptography.hazmat.backends import default_backend
7+
from cryptography.hazmat.primitives import serialization
68

79
from atlassian_jwt_auth import algorithms
810
from atlassian_jwt_auth import key
@@ -16,6 +18,7 @@ def __init__(self, issuer, private_key_retriever, **kwargs):
1618
self.lifetime = kwargs.get('lifetime', datetime.timedelta(hours=1))
1719
self.algorithm = kwargs.get('algorithm', 'RS256')
1820
self.subject = kwargs.get('subject', None)
21+
self._private_keys_cache = dict()
1922

2023
if self.algorithm not in set(
2124
algorithms.get_permitted_algorithm_names()):
@@ -25,6 +28,25 @@ def __init__(self, issuer, private_key_retriever, **kwargs):
2528
raise ValueError("lifetime, '%s',exceeds the allowed 1 hour max" %
2629
(self.lifetime))
2730

31+
def _obtain_private_key(self, key_identifier, private_key_pem):
32+
""" returns a loaded instance of the given private key either from
33+
cache or from the given private_key_pem.
34+
"""
35+
priv_key = self._private_keys_cache.get(key_identifier.key_id, None)
36+
if priv_key is not None:
37+
return priv_key
38+
if not isinstance(private_key_pem, bytes):
39+
private_key_pem = private_key_pem.encode()
40+
priv_key = serialization.load_pem_private_key(
41+
private_key_pem,
42+
password=None,
43+
backend=default_backend()
44+
)
45+
if len(self._private_keys_cache) > 10:
46+
self._private_keys_cache = dict()
47+
self._private_keys_cache[key_identifier.key_id] = priv_key
48+
return priv_key
49+
2850
def _generate_claims(self, audience, **kwargs):
2951
""" returns a new dictionary of claims. """
3052
now = self._now()
@@ -48,9 +70,11 @@ def generate_jwt(self, audience, **kwargs):
4870
""" returns a new signed jwt for use. """
4971
key_identifier, private_key_pem = self.private_key_retriever.load(
5072
self.issuer)
73+
private_key = self._obtain_private_key(
74+
key_identifier, private_key_pem)
5175
return jwt.encode(
5276
self._generate_claims(audience, **kwargs),
53-
key=private_key_pem,
77+
key=private_key,
5478
algorithm=self.algorithm,
5579
headers={'kid': key_identifier.key_id})
5680

atlassian_jwt_auth/tests/test_signer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33

44
import mock
5+
from cryptography.hazmat.primitives import serialization
56

67
import atlassian_jwt_auth
78
from atlassian_jwt_auth.tests import utils
@@ -72,9 +73,16 @@ def test_generate_jwt(self, m_jwt_encode):
7273
jwt_auth_signer.generate_jwt(expected_aud)
7374
m_jwt_encode.assert_called_with(
7475
expected_claims,
75-
key=self._private_key_pem,
76+
key=mock.ANY,
7677
algorithm=self.algorithm,
7778
headers={'kid': expected_key_id})
79+
for name, args, kwargs in m_jwt_encode.mock_calls:
80+
call_private_key = kwargs['key'].private_bytes(
81+
encoding=serialization.Encoding.PEM,
82+
format=serialization.PrivateFormat.TraditionalOpenSSL,
83+
encryption_algorithm=serialization.NoEncryption()
84+
)
85+
self.assertEqual(call_private_key, self._private_key_pem)
7886

7987

8088
class JWTAuthSignerRS256Test(

0 commit comments

Comments
 (0)