Skip to content

Commit 256d1ae

Browse files
committed
Added caching of temporary AWS credentials on EC2 instances with IAM roles
1 parent 567d246 commit 256d1ae

4 files changed

Lines changed: 236 additions & 0 deletions

File tree

src/bd2k/util/ec2/__init__.py

Whitespace-only changes.

src/bd2k/util/ec2/credentials.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import errno
2+
import logging
3+
import threading
4+
import time
5+
from datetime import datetime
6+
7+
import os
8+
from bd2k.util.files import mkdir_p
9+
10+
# Module-level logger, named after this module.
log = logging.getLogger( __name__ )

# Location of the credential cache file. The '~' is expanded where the path is used.
cache_path = '~/.cache/aws/cached_temporary_credentials'

# Format of the expiry timestamp as stored in the cache file
# (incidentally the same as the format used by AWS).
datetime_format = "%Y-%m-%dT%H:%M:%SZ"
15+
16+
17+
def datetime_to_str( dt ):
    """
    Render a naive datetime object, taken to be in UTC, as a string that spells out UTC.

    >>> datetime_to_str( datetime( 1970, 1, 1, 0, 0, 0 ) )
    '1970-01-01T00:00:00Z'
    """
    formatted = dt.strftime( datetime_format )
    return formatted
25+
26+
27+
def str_to_datetime( s ):
    """
    Parse a string that spells out UTC into a naive datetime object that is implicitly UTC.

    >>> str_to_datetime( '1970-01-01T00:00:00Z' )
    datetime.datetime(1970, 1, 1, 0, 0)

    The repr above leaves out the zero seconds, just like this equivalent constructor call:

    >>> datetime(1970, 1, 1, 0, 0, 0)
    datetime.datetime(1970, 1, 1, 0, 0)
    """
    parsed = datetime.strptime( s, datetime_format )
    return parsed
39+
40+
41+
# Held while Boto's Provider class is being patched or restored.
monkey_patch_lock = threading.RLock( )
# Boto's original Provider._populate_keys_from_metadata_server while the patch is in place,
# None otherwise.
_populate_keys_from_metadata_server_orig = None
43+
44+
45+
def enable_metadata_credential_caching( ):
    """
    Monkey-patch Boto such that multiple processes using it share a single set of cached,
    temporary IAM role credentials. This helps avoid hitting request limits imposed on the
    metadata service when too many processes request those credentials concurrently.

    Calling this function again while the patch is in place has no further effect. It should be
    called before any attempt is made to connect to AWS with Boto.
    """
    global _populate_keys_from_metadata_server_orig
    with monkey_patch_lock:
        if _populate_keys_from_metadata_server_orig is not None:
            # The patch is already in place
            return
        from boto.provider import Provider
        # Remember the original method so the patch can delegate to it and be reverted later
        _populate_keys_from_metadata_server_orig = Provider._populate_keys_from_metadata_server
        Provider._populate_keys_from_metadata_server = _populate_keys_from_metadata_server
59+
60+
61+
def disable_metadata_credential_caching( ):
    """
    Undo the effect of enable_metadata_credential_caching() by restoring Boto's original method.
    Does nothing if the patch is not currently in place.
    """
    global _populate_keys_from_metadata_server_orig
    with monkey_patch_lock:
        if _populate_keys_from_metadata_server_orig is None:
            # Nothing to revert
            return
        from boto.provider import Provider
        Provider._populate_keys_from_metadata_server = _populate_keys_from_metadata_server_orig
        # Mark the patch as removed so that it can be enabled again
        _populate_keys_from_metadata_server_orig = None
71+
72+
73+
def _populate_keys_from_metadata_server( self ):
    """
    Replacement for Boto's Provider._populate_keys_from_metadata_server that shares credentials
    between processes by way of a cache file.

    The cache file at cache_path holds four newline-separated fields: access key, secret key,
    security token and expiry time, the latter formatted according to datetime_format. An empty
    cache file is written when the credentials obtained from the original method carried no
    expiry time.

    If the cache file can be read and self._credentials_need_refresh() is false, the cached
    values are used. Otherwise, processes race to exclusively create cache_path + '.tmp'. The
    winner invokes the original Boto method, writes the result to the temporary file and renames
    it to the cache file. Every other process polls until the temporary file is gone and then
    starts over.

    :param self: the Provider instance whose _access_key, _secret_key, _security_token and
           _credential_expiry_time attributes are populated

    :raise IOError: if reading the cache file fails for any reason other than it being missing

    :raise OSError: if creating the temporary file fails for any reason other than it existing

    Any exception raised by the original Boto method is propagated after the temporary file has
    been removed.
    """
    global _populate_keys_from_metadata_server_orig
    path = os.path.expanduser( cache_path )
    tmp_path = path + '.tmp'
    while True:
        log.debug( 'Attempting to read cached credentials from %s.', path )
        try:
            with open( path, 'r' ) as f:
                content = f.read( )
            # Test the raw content rather than the result of split(): ''.split( '\n' ) yields
            # [ '' ], which is truthy, so an empty file would be mistaken for a record and the
            # field access below would fail with an IndexError.
            if content:
                record = content.split( '\n' )
                self._access_key = record[ 0 ]
                self._secret_key = record[ 1 ]
                self._security_token = record[ 2 ]
                self._credential_expiry_time = str_to_datetime( record[ 3 ] )
            else:
                # NOTE(review): this returns without assigning any keys to self -- confirm that
                # callers cope with that.
                log.debug( '%s is empty. Credentials are not temporary.', path )
                return
        except IOError as e:
            if e.errno == errno.ENOENT:
                log.debug( 'Cached credentials are missing.' )
                dir_path = os.path.dirname( path )
                if not os.path.exists( dir_path ):
                    log.debug( 'Creating parent directory %s', dir_path )
                    # A race would be ok at this point
                    mkdir_p( dir_path )
            else:
                raise
        else:
            if self._credentials_need_refresh( ):
                log.debug( 'Cached credentials are expired.' )
            else:
                log.debug( 'Cached credentials exist and are still fresh.' )
                return
        # We get here if credentials are missing or expired
        log.debug( 'Racing to create %s.', tmp_path )
        # Only one process, the winner, will succeed. O_EXCL makes the creation fail with EEXIST
        # if the file already exists. The 0o prefix is valid octal syntax on Python 2.6+ and 3.
        try:
            fd = os.open( tmp_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600 )
        except OSError as e:
            if e.errno == errno.EEXIST:
                log.debug( 'Lost the race to create %s. Waiting on winner to remove it.', tmp_path )
                # NOTE(review): there is no timeout, this polls for as long as the file exists.
                while os.path.exists( tmp_path ):
                    time.sleep( .1 )
                log.debug( 'Winner removed %s. Trying from the top.', tmp_path )
            else:
                raise
        else:
            try:
                log.debug( 'Won the race to create %s. '
                           'Requesting credentials from metadata service.', tmp_path )
                _populate_keys_from_metadata_server_orig( self )
            except:
                # Bare except is intentional: clean up on any failure, then re-raise it.
                os.close( fd )
                fd = None
                log.debug( 'Failed to obtain credentials, removing %s.', tmp_path )
                # This unblocks the losers.
                os.unlink( tmp_path )
                # Bail out. It's too likely to happen repeatedly
                raise
            else:
                if self._credential_expiry_time is None:
                    os.close( fd )
                    fd = None
                    log.debug( 'Credentials are not temporary. '
                               'Leaving %s empty and renaming it to %s.', tmp_path, path )
                else:
                    log.debug( 'Writing credentials to %s.', tmp_path )
                    with os.fdopen( fd, 'w' ) as fh:
                        # The file object owns the descriptor now and closes it on exit
                        fd = None
                        fh.write( '\n'.join( [
                            self._access_key,
                            self._secret_key,
                            self._security_token,
                            datetime_to_str( self._credential_expiry_time ) ] ) )
                    log.debug( 'Wrote credentials to %s. '
                               'Renaming it to %s.', tmp_path, path )
                # The rename makes the cache file appear and the temporary file disappear
                os.rename( tmp_path, path )
                return
            finally:
                # Close the descriptor unless one of the branches above already disposed of it
                if fd is not None:
                    os.close( fd )

src/bd2k/util/ec2/test/__init__.py

Whitespace-only changes.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import logging
2+
3+
import errno
4+
5+
import os
6+
import unittest
7+
8+
from bd2k.util.ec2.credentials import (enable_metadata_credential_caching,
9+
disable_metadata_credential_caching, cache_path)
10+
11+
12+
def get_access_key( ):
    """
    Create a fresh Boto provider for AWS and return the access key it reports.
    """
    from boto.provider import Provider
    return Provider( 'aws' ).get_access_key( )
16+
17+
18+
class CredentialsTest( unittest.TestCase ):
    """
    Tests the caching of credentials enabled by enable_metadata_credential_caching(). The cache
    file is removed before and after every test.
    """

    def __init__( self, *args, **kwargs ):
        super( CredentialsTest, self ).__init__( *args, **kwargs )
        # Absolute path of the cache file that the module under test reads and writes
        self.cache_path = os.path.expanduser( cache_path )

    @classmethod
    def setUpClass( cls ):
        super( CredentialsTest, cls ).setUpClass( )
        logging.basicConfig( level=logging.DEBUG )

    def setUp( self ):
        super( CredentialsTest, self ).setUp( )
        self.cleanUp( )

    def cleanUp( self ):
        """
        Remove the cache file, tolerating its absence.
        """
        try:
            os.unlink( self.cache_path )
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def tearDown( self ):
        super( CredentialsTest, self ).tearDown( )
        self.cleanUp( )

    def test_metadata_credential_caching( self ):
        """
        Brute forces many concurrent requests for getting temporary credentials. If you comment
        out the calls to enable_metadata_credential_caching, you should see some failures due to
        requests timing out. The test will also take much longer in that case.
        """
        num_tests = 1000
        num_processes = 32
        # Get key without caching
        access_key = get_access_key( )
        self.assertFalse( os.path.exists( self.cache_path ) )
        enable_metadata_credential_caching( )
        # Again for idempotence
        enable_metadata_credential_caching( )
        try:
            futures = [ ]
            from multiprocessing import Pool
            pool = Pool( num_processes )
            try:
                for _ in range( num_tests ):
                    futures.append( pool.apply_async( get_access_key ) )
            except:
                # Bare except is intentional: tear the pool down on any failure, then re-raise
                pool.close( )
                pool.terminate( )
                raise
            else:
                pool.close( )
                pool.join( )
        finally:
            disable_metadata_credential_caching( )
            # Again for idempotence
            disable_metadata_credential_caching( )
        self.assertTrue( os.path.exists( self.cache_path ) )
        # assertEqual instead of the deprecated alias assertEquals, which Python 3.12 removed
        self.assertEqual( len( futures ), num_tests )
        access_keys = [ f.get( ) for f in futures ]
        self.assertEqual( len( access_keys ), num_tests )
        # Every request must have produced the same key as the request made without caching
        access_keys = set( access_keys )
        self.assertEqual( len( access_keys ), 1 )
        self.assertEqual( access_keys.pop( ), access_key )

0 commit comments

Comments
 (0)