__init__.py 35.4 KB
Newer Older
Max Beckett's avatar
Max Beckett committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
#
# Copyright 2017-2018 Amazon.com, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the MIT License. See the LICENSE accompanying this file
# for the specific language governing permissions and limitations under
# the License.
#

10
import base64
Max Beckett's avatar
Max Beckett committed
11
import errno
12
13
import hashlib
import hmac
Max Beckett's avatar
Max Beckett committed
14
15
16
17
import json
import logging
import logging.handlers
import os
18
19
20
import pwd
import re
import shutil
Max Beckett's avatar
Max Beckett committed
21
22
23
24
25
import subprocess
import sys
import time

from collections import namedtuple
26
27
from contextlib import contextmanager
from datetime import datetime, timedelta
Max Beckett's avatar
Max Beckett committed
28
from logging.handlers import RotatingFileHandler
29
from signal import SIGTERM, SIGHUP
Max Beckett's avatar
Max Beckett committed
30
31
32

try:
    import ConfigParser
33
    from ConfigParser import NoOptionError, NoSectionError
Max Beckett's avatar
Max Beckett committed
34
except ImportError:
35
    from configparser import ConfigParser, NoOptionError, NoSectionError
Max Beckett's avatar
Max Beckett committed
36

37
38
39
40
41
42
43
44
45
46
47
try:
    from urllib.parse import quote_plus
except ImportError:
    from urllib import quote_plus

try:
    from urllib2 import urlopen, URLError
except ImportError:
    from urllib.error import URLError
    from urllib.request import urlopen

48
# Version string printed by parse_arguments() for --version.
VERSION = '1.21'
# NOTE(review): presumably the AWS service name used by the SigV4 signing helpers - confirm against them.
SERVICE = 'elasticfilesystem'
Max Beckett's avatar
Max Beckett committed
50
51
52
53
54
55
56
57
58

# Default configuration file for read_config() and the section the watchdog reads its options from.
CONFIG_FILE = '/etc/amazon/efs/efs-utils.conf'
CONFIG_SECTION = 'mount-watchdog'

# Directory and file name of the rotating log configured by bootstrap_logging().
LOG_DIR = '/var/log/amazon/efs'
LOG_FILE = 'mount-watchdog.log'

# Default directory scanned for per-mount state files.
STATE_FILE_DIR = '/var/run/efs'

59
60
61
62
63
64
65
66
# Path returned by get_private_key_path(); the RSA key is shared by all mounts.
PRIVATE_KEY_FILE = '/etc/amazon/efs/privateKey.pem'
# check_certificate() refreshes a self-signed certificate older than this many seconds.
REFRESH_SELF_SIGNED_CERT_INTERVAL_SEC = 60 * 60
# Offsets recreate_certificate() passes to get_certificate_timestamp() for the validity window.
NOT_BEFORE_MINS = 15
NOT_AFTER_HOURS = 3
# strftime/strptime patterns; CERT_DATETIME_FORMAT is also the format of 'certificateCreationTime' in state files.
DATE_ONLY_FORMAT = '%Y%m%d'
SIGV4_DATETIME_FORMAT = '%Y%m%dT%H%M%SZ'
CERT_DATETIME_FORMAT = '%y%m%d%H%M%SZ'

67
68
69
70
71
# Maps the lookup method name ('credentials' / 'config') to the matching file under the
# current user's ~/.aws directory; consumed by get_aws_security_credentials_from_file().
AWS_CREDENTIALS_FILES = dict(
    (aws_file, os.path.expanduser(os.path.join('~' + pwd.getpwuid(os.getuid()).pw_name, '.aws', aws_file)))
    for aws_file in ('credentials', 'config')
)

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Template of the OpenSSL ca/req configuration written by create_ca_conf(). The five %s slots are
# filled, in order, with: the mount directory, the private key path, the common name, the
# "[ v3_ca ]" extension section and the "[ efs_client_auth ]" section (empty when no credentials).
CA_CONFIG_BODY = """dir = %s
RANDFILE = $dir/database/.rand

[ ca ]
default_ca = local_ca

[ local_ca ]
database = $dir/database/index.txt
serial = $dir/database/serial
private_key = %s
cert = $dir/certificate.pem
new_certs_dir = $dir/certs
default_md = sha256
preserve = no
policy = efsPolicy
x509_extensions = v3_ca

[ efsPolicy ]
CN = supplied

[ req ]
prompt = no
distinguished_name = req_distinguished_name

[ req_distinguished_name ]
CN = %s

%s

%s
"""

# SigV4 Auth
ALGORITHM = 'AWS4-HMAC-SHA256'
AWS4_REQUEST = 'aws4_request'

HTTP_REQUEST_METHOD = 'GET'
CANONICAL_URI = '/'
# Header name -> value template; the '%s' placeholder is left in place for later substitution.
CANONICAL_HEADERS_DICT = {
    'host': '%s'
}
CANONICAL_HEADERS = '\n'.join(['%s:%s' % (k, v) for k, v in sorted(CANONICAL_HEADERS_DICT.items())])
# Sorted like CANONICAL_HEADERS so both lists stay in the same order if more headers are added.
SIGNED_HEADERS = ';'.join(sorted(CANONICAL_HEADERS_DICT))
REQUEST_PAYLOAD = ''

# Shape of a valid access point id; check_certificate() rejects state files whose id does not match.
AP_ID_RE = re.compile(r'^fsap-[0-9a-f]{17}$')

119
# Base URL that get_aws_security_credentials_from_ecs() prefixes to the credentials URI. It must
# carry a scheme: urlopen() raises ValueError ("unknown url type") for a bare host.
ECS_TASK_METADATA_API = 'http://169.254.170.2'
# Instance metadata endpoint listing the IAM role name; the role name is appended to fetch credentials.
INSTANCE_IAM_URL = 'http://169.254.169.254/latest/meta-data/iam/security-credentials/'
# Documentation links included in the "unable to reach" log messages.
SECURITY_CREDS_ECS_URI_HELP_URL = 'https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html'
SECURITY_CREDS_IAM_ROLE_HELP_URL = 'https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html'

Max Beckett's avatar
Max Beckett committed
124
125
126
127
128
129
130
131
132
133
134
135
# One whitespace-split line of the mounts file, built by get_current_local_nfs_mounts().
Mount = namedtuple('Mount', ['server', 'mountpoint', 'type', 'options', 'freq', 'passno'])


def fatal_error(user_message, log_message=None):
    """
    Report an unrecoverable error and terminate the process with exit status 1.

    user_message is written to stderr; log_message (or user_message when it is None) is
    logged at error level.
    """
    message_for_log = user_message if log_message is None else log_message

    sys.stderr.write('%s\n' % user_message)
    logging.error(message_for_log)
    sys.exit(1)


136
def get_aws_security_credentials(credentials_source):
    """
    Lookup AWS security credentials (access key ID and secret access key). Adapted credentials provider chain from:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html and
    https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html

    credentials_source is a "<method>:<value>" string read from the mount state file, where method
    is one of 'credentials', 'config', 'ecs' or 'metadata'. Returns whatever the selected provider
    returns, or None (after logging an error) when the string is malformed or the method is unknown.
    """
    # partition() never raises; a string without ':' is reported as malformed instead of
    # failing with a ValueError on tuple unpacking.
    method, separator, value = credentials_source.partition(':')

    if separator:
        if method in ('credentials', 'config'):
            return get_aws_security_credentials_from_file(method, value)
        if method == 'ecs':
            return get_aws_security_credentials_from_ecs(value)
        if method == 'metadata':
            return get_aws_security_credentials_from_instance_metadata()

    logging.error('Improper credentials source string "%s" found from mount state file', credentials_source)
    return None

156

157
158
159
160
161
def get_aws_security_credentials_from_file(file_name, awsprofile):
    """
    Look up credentials for the named profile in one of the AWS files (~/.aws/credentials or ~/.aws/config).

    file_name is a key of AWS_CREDENTIALS_FILES ('credentials' or 'config'). Returns the credentials
    dict when the file exists and the profile has an access key id; otherwise logs an error and
    returns None.
    """
    file_path = AWS_CREDENTIALS_FILES.get(file_name)
    if file_path and os.path.exists(file_path):
        credentials = credentials_file_helper(file_path, awsprofile)
        if credentials['AccessKeyId']:
            return credentials

    logging.error('AWS security credentials not found in %s under named profile [%s]', file_path, awsprofile)
    return None
167
168


169
def get_aws_security_credentials_from_ecs(uri):
    """
    Fetch credentials through the ECS security credentials uri found in the
    AWS_CONTAINER_CREDENTIALS_RELATIVE_URI environment variable.

    uri is appended to ECS_TASK_METADATA_API. Returns the response dict when it contains
    'AccessKeyId', 'SecretAccessKey' and 'Token'; otherwise None.
    """
    dict_keys = ['AccessKeyId', 'SecretAccessKey', 'Token']
    ecs_uri = ECS_TASK_METADATA_API + uri
    ecs_unsuccessful_resp = 'Unsuccessful retrieval of AWS security credentials at %s.' % ecs_uri
    # Format the message here; a trailing ", args" would build a tuple rather than a string.
    ecs_url_error_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
                        (ecs_uri, SECURITY_CREDS_ECS_URI_HELP_URL)
    ecs_security_dict = url_request_helper(ecs_uri, ecs_unsuccessful_resp, ecs_url_error_msg)

    # url_request_helper hands back the raw body when it is not JSON; only a dict can hold credentials.
    if isinstance(ecs_security_dict, dict) and all(k in ecs_security_dict for k in dict_keys):
        return ecs_security_dict

    return None
182
183


184
def get_aws_security_credentials_from_instance_metadata():
    """
    Fetch credentials through the IAM role name security credentials lookup uri (after looking up
    the IAM role name attached to the instance).

    Returns the response dict when it contains 'AccessKeyId', 'SecretAccessKey' and 'Token';
    otherwise None.
    """
    dict_keys = ['AccessKeyId', 'SecretAccessKey', 'Token']
    iam_role_unsuccessful_resp = 'Unsuccessful retrieval of IAM role name at %s.' % INSTANCE_IAM_URL
    # Format the messages here; a trailing ", args" would build a tuple rather than a string.
    iam_role_url_error_msg = 'Unable to reach %s to retrieve IAM role name. See %s for more info.' % \
                             (INSTANCE_IAM_URL, SECURITY_CREDS_IAM_ROLE_HELP_URL)
    iam_role_name = url_request_helper(INSTANCE_IAM_URL, iam_role_unsuccessful_resp, iam_role_url_error_msg)
    if iam_role_name:
        # A non-JSON body comes back raw, which is bytes on Python 3; str() on bytes would
        # produce "b'role'" and corrupt the lookup URL, so decode it first.
        if isinstance(iam_role_name, bytes) and not isinstance(iam_role_name, str):
            iam_role_name = iam_role_name.decode('utf-8')

        security_creds_lookup_url = INSTANCE_IAM_URL + str(iam_role_name)
        unsuccessful_resp = 'Unsuccessful retrieval of AWS security credentials at %s.' % security_creds_lookup_url
        url_error_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
                        (security_creds_lookup_url, SECURITY_CREDS_IAM_ROLE_HELP_URL)
        iam_security_dict = url_request_helper(security_creds_lookup_url, unsuccessful_resp, url_error_msg)

        # Only a parsed JSON object can hold the credentials; a raw body is rejected.
        if isinstance(iam_security_dict, dict) and all(k in iam_security_dict for k in dict_keys):
            return iam_security_dict

    return None
202
203


204
def credentials_file_helper(file_path, awsprofile):
    """
    Read the [awsprofile] section of the AWS credentials/config file at file_path.

    Returns a dict with the keys 'AccessKeyId', 'SecretAccessKey' and 'Token'. Values that could
    not be read stay None; a missing session token alone still yields the two key values.
    """
    aws_credentials_configs = read_config(file_path)
    credentials = {'AccessKeyId': None, 'SecretAccessKey': None, 'Token': None}

    try:
        aws_access_key_id = aws_credentials_configs.get(awsprofile, 'aws_access_key_id')
        secret_access_key = aws_credentials_configs.get(awsprofile, 'aws_secret_access_key')
        session_token = aws_credentials_configs.get(awsprofile, 'aws_session_token')

        credentials['AccessKeyId'] = aws_access_key_id
        credentials['SecretAccessKey'] = secret_access_key
        credentials['Token'] = session_token
    except NoOptionError as e:
        if 'aws_access_key_id' in str(e) or 'aws_secret_access_key' in str(e):
            logging.debug('aws_access_key_id or aws_secret_access_key not found in %s under named profile [%s]', file_path,
                          awsprofile)
        if 'aws_session_token' in str(e):
            # The token is read last, so both keys are known to be present at this point.
            logging.debug('aws_session_token not found in %s', file_path)
            credentials['AccessKeyId'] = aws_credentials_configs.get(awsprofile, 'aws_access_key_id')
            credentials['SecretAccessKey'] = aws_credentials_configs.get(awsprofile, 'aws_secret_access_key')
    except NoSectionError:
        logging.debug('No [%s] section found in config file %s', awsprofile, file_path)

    return credentials


def url_request_helper(url, unsuccessful_resp, url_error_msg):
    """
    GET url (1 second timeout) and return its body.

    Returns the parsed JSON value when the body is valid JSON, the raw response body when it is
    not, and None when the response code is not 200 or a URLError occurs. unsuccessful_resp and
    url_error_msg are only used as prefixes of the corresponding debug log messages.
    """
    try:
        request_resp = urlopen(url, timeout=1)
        try:
            if request_resp.getcode() != 200:
                # Pass the caller's text as an argument: it embeds a URL, and a literal '%'
                # inside the format string would break log formatting.
                logging.debug('%s %s: ResponseCode=%d', unsuccessful_resp, url, request_resp.getcode())
                return None

            resp_body = request_resp.read()
            try:
                if isinstance(resp_body, str):
                    resp_dict = json.loads(resp_body)
                else:
                    resp_dict = json.loads(resp_body.decode(request_resp.headers.get_content_charset() or 'us-ascii'))

                return resp_dict
            except ValueError as e:
                logging.debug('Error parsing json: %s, returning raw response body: %s' % (e, str(resp_body)))
                return resp_body
        finally:
            # Release the connection on every path.
            request_resp.close()

    except URLError as e:
        logging.debug('%s %s', url_error_msg, e)
        return None


Max Beckett's avatar
Max Beckett committed
255
def bootstrap_logging(config, log_dir=LOG_DIR):
    """
    Attach a rotating file handler for LOG_FILE under log_dir to the root logger.

    The level, maximum file size and backup count come from the 'logging_level',
    'logging_max_bytes' and 'logging_file_count' options of CONFIG_SECTION. An unrecognised
    level falls back to INFO and is reported once the handler is installed.
    """
    raw_level = config.get(CONFIG_SECTION, 'logging_level')
    levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'critical': logging.CRITICAL
    }
    level = levels.get(raw_level.lower())
    level_error = False

    if not level:
        # delay logging error about malformed log level until after logging is configured
        level_error = True
        level = logging.INFO

    max_bytes = config.getint(CONFIG_SECTION, 'logging_max_bytes')
    file_count = config.getint(CONFIG_SECTION, 'logging_file_count')

    handler = RotatingFileHandler(os.path.join(log_dir, LOG_FILE), maxBytes=max_bytes, backupCount=file_count)
    handler.setFormatter(logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s'))

    logger = logging.getLogger()
    logger.setLevel(level)
    logger.addHandler(handler)

    if level_error:
        logging.error('Malformed logging level "%s", setting logging level to %s', raw_level, level)

Max Beckett's avatar
Max Beckett committed
285

286
287
288
289
290
291
292
293
294
295
296
def parse_options(options):
    """
    Parse a comma-separated mount option string into a dict.

    'key=value' entries map key to value; bare flags map to None. Only the first '=' separates key
    from value, so a value that itself contains '=' is kept whole instead of raising ValueError.
    """
    opts = {}
    for o in options.split(','):
        key, separator, value = o.partition('=')
        opts[key] = value if separator else None
    return opts


297
298
def get_file_safe_mountpoint(mount):
    """
    Build the "<mountpoint>.<port>" key used in state file names for a Mount.

    The absolute mountpoint has its path separators replaced by '.' and one leading '.' removed.
    Returns None when the mount options carry no 'port'.
    """
    mountpoint = os.path.abspath(mount.mountpoint).replace(os.sep, '.')
    if mountpoint.startswith('.'):
        mountpoint = mountpoint[1:]

    opts = parse_options(mount.options)
    if 'port' not in opts:
        # some other localhost nfs mount not running over stunnel
        return None
    return mountpoint + '.' + opts['port']
Max Beckett's avatar
Max Beckett committed
307
308
309
310


def get_current_local_nfs_mounts(mount_file='/proc/mounts'):
    """
    Return a dict of the current NFS mounts for servers running on localhost, keyed by the mountpoint and port as it
    appears in EFS watchdog state files.

    Each line of mount_file is split on whitespace into a Mount. Only mounts whose server starts
    with '127.0.0.1' and whose type contains 'nfs' are kept; mounts without a 'port' option are
    dropped.
    """
    mounts = []

    with open(mount_file) as f:
        for mount in f:
            mounts.append(Mount._make(mount.strip().split()))

    mounts = [m for m in mounts if m.server.startswith('127.0.0.1') and 'nfs' in m.type]

    mount_dict = {}
    for m in mounts:
        safe_mnt = get_file_safe_mountpoint(m)
        if safe_mnt:
            mount_dict[safe_mnt] = m

    return mount_dict


def get_state_files(state_file_dir):
    """
    Return a dict of the state file names in state_file_dir, keyed by the mountpoint and port portion of the filename.

    Only entries starting with 'fs-' that are not directories are considered. The values are bare
    file names, not absolute paths. A missing state_file_dir yields an empty dict.
    """
    state_files = {}

    if os.path.isdir(state_file_dir):
        for sf in os.listdir(state_file_dir):
            if not sf.startswith('fs-') or os.path.isdir(os.path.join(state_file_dir, sf)):
                continue

            # This translates the state file name "fs-deadbeaf.home.user.mnt.12345"
            # into file-safe mountpoint "home.user.mnt.12345"
            first_period = sf.find('.')
            mount_point_and_port = sf[first_period + 1:]
            logging.debug('Translating "%s" into mount point and port "%s"', sf, mount_point_and_port)
            state_files[mount_point_and_port] = sf

    return state_files


def is_pid_running(pid):
    """Return True when signal 0 can be delivered to pid, False when os.kill raises OSError."""
    try:
        # Signal 0 performs the existence/permission check without sending anything.
        os.kill(pid, 0)
    except OSError:
        return False
    return True


def start_tls_tunnel(child_procs, state_file, command):
    """
    Launch the TLS tunnel described by command and return its pid.

    The Popen object is appended to child_procs. If the process is not running right after
    launch, fatal_error() terminates the watchdog.
    """
    # launch the tunnel in a process group so if it has any child processes, they can be killed easily
    logging.info('Starting TLS tunnel: "%s"', ' '.join(command))
    tunnel = subprocess.Popen(command, preexec_fn=os.setsid, close_fds=True)

    if not is_pid_running(tunnel.pid):
        fatal_error('Failed to initialize TLS tunnel for %s' % state_file, 'Failed to start TLS tunnel.')

    logging.info('Started TLS tunnel, pid: %d', tunnel.pid)

    child_procs.append(tunnel)
    return tunnel.pid


374
def clean_up_mount_state(state_file_dir, state_file, pid, is_running, mount_state_dir=None):
    """
    Terminate the tunnel of an unmounted file system and delete its on-disk state.

    When is_running is true, SIGTERM is sent to the tunnel's process group. If the pid is still
    alive afterwards nothing is deleted, so a later call can retry. Otherwise the files listed
    under 'files' in the state file, the state file itself and, if given, the mount_state_dir
    directory under state_file_dir are removed.
    """
    if is_running:
        process_group = os.getpgid(pid)
        logging.info('Terminating running TLS tunnel - PID: %d, group ID: %s', pid, process_group)
        os.killpg(process_group, SIGTERM)

    if is_pid_running(pid):
        logging.info('TLS tunnel: %d is still running, will retry termination', pid)
    else:
        logging.info('TLS tunnel: %d is no longer running, cleaning up state', pid)
        state_file_path = os.path.join(state_file_dir, state_file)
        with open(state_file_path) as f:
            state = json.load(f)

        for tracked_file in state.get('files', list()):
            logging.debug('Deleting %s', tracked_file)
            try:
                os.remove(tracked_file)
                logging.debug('Deleted %s', tracked_file)
            except OSError as e:
                # a file that is already gone is fine; anything else is unexpected
                if e.errno != errno.ENOENT:
                    raise

        os.remove(state_file_path)

        if mount_state_dir is not None:
            mount_state_dir_abs_path = os.path.join(state_file_dir, mount_state_dir)
            if os.path.isdir(mount_state_dir_abs_path):
                shutil.rmtree(mount_state_dir_abs_path)
            else:
                logging.debug('Attempt to remove mount state directory %s failed. Directory is not present.',
                              mount_state_dir_abs_path)

Max Beckett's avatar
Max Beckett committed
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425

def rewrite_state_file(state, state_file_dir, state_file):
    """
    Replace state_file in state_file_dir with the JSON serialisation of state.

    The data is written to a '~'-prefixed sibling file first and then renamed over the target.
    """
    target_path = os.path.join(state_file_dir, state_file)
    scratch_path = os.path.join(state_file_dir, '~%s' % state_file)

    with open(scratch_path, 'w') as scratch:
        json.dump(state, scratch)

    os.rename(scratch_path, target_path)


def mark_as_unmounted(state, state_file_dir, state_file, current_time):
    """Record current_time as 'unmount_time' in state, persist the state file and return state."""
    logging.debug('Marking %s as unmounted at %d', state_file, current_time)

    state.update({'unmount_time': current_time})
    rewrite_state_file(state, state_file_dir, state_file)
    return state


def restart_tls_tunnel(child_procs, state, state_file_dir, state_file):
    """
    Start a new tunnel from state['cmd'] and store its pid in the state file.

    Nothing is started when the state names a certificate that no longer exists on disk.
    """
    if 'certificate' in state and not os.path.exists(state['certificate']):
        logging.error('Cannot restart stunnel because self-signed certificate at %s is missing' % state['certificate'])
        return

    new_tunnel_pid = start_tls_tunnel(child_procs, state_file, state['cmd'])
    state['pid'] = new_tunnel_pid

    logging.debug('Rewriting %s with new pid: %d', state_file, new_tunnel_pid)
    rewrite_state_file(state, state_file_dir, state_file)


437
def check_efs_mounts(config, child_procs, unmount_grace_period_sec, state_file_dir=STATE_FILE_DIR):
    """
    Reconcile every state file in state_file_dir with the current local NFS mounts.

    For each state file:
    - already marked unmounted: clean up once unmount_grace_period_sec has elapsed;
    - no matching mount: mark it as unmounted now;
    - still mounted: refresh the certificate if the state has one, and restart the tunnel if its
      pid is not running.
    State files that do not contain valid JSON are logged and skipped.
    """
    nfs_mounts = get_current_local_nfs_mounts()
    logging.debug('Current local NFS mounts: %s', list(nfs_mounts.values()))

    state_files = get_state_files(state_file_dir)
    logging.debug('Current state files in "%s": %s', state_file_dir, list(state_files.values()))

    for mount, state_file in state_files.items():
        state_file_path = os.path.join(state_file_dir, state_file)
        with open(state_file_path) as f:
            try:
                state = json.load(f)
            except ValueError:
                logging.exception('Unable to parse json in %s', state_file_path)
                continue

        is_running = is_pid_running(state['pid'])

        current_time = time.time()
        if 'unmount_time' in state:
            if state['unmount_time'] + unmount_grace_period_sec < current_time:
                logging.info('Unmount grace period expired for %s', state_file)
                clean_up_mount_state(state_file_dir, state_file, state['pid'], is_running, state.get('mountStateDir'))

        elif mount not in nfs_mounts:
            logging.info('No mount found for "%s"', state_file)
            state = mark_as_unmounted(state, state_file_dir, state_file, current_time)

        else:
            if 'certificate' in state:
                # NOTE(review): check_certificate falls back to its own default base_path here, not
                # state_file_dir - confirm this is intended when a non-default directory is passed.
                check_certificate(config, state, state_file_dir, state_file)

            if is_running:
                logging.debug('TLS tunnel for %s is running', state_file)
            else:
                logging.warning('TLS tunnel for %s is not running', state_file)
                restart_tls_tunnel(child_procs, state, state_file_dir, state_file)


def check_child_procs(child_procs):
    """
    Poll every tracked tunnel process and drop the ones that have exited from child_procs.

    child_procs is modified in place. Each removed process is reported at warning level.
    """
    # Iterate over a snapshot: removing from the list being iterated would skip the entry that
    # follows each removed one.
    for proc in list(child_procs):
        proc.poll()
        if proc.returncode is not None:
            logging.warning('Child TLS tunnel process %d has exited, returncode=%d', proc.pid, proc.returncode)
            child_procs.remove(proc)


def parse_arguments(args=None):
    """
    Handle the command line flags; args defaults to sys.argv.

    '-h' / '--help' prints the usage line and '--version' prints VERSION; both then exit with
    status 0. Returns None when neither flag is present.
    """
    if args is None:
        args = sys.argv

    if '-h' in args[1:] or '--help' in args[1:]:
        sys.stdout.write('Usage: %s [--version] [-h|--help]\n' % args[0])
        sys.exit(0)

    if '--version' in args[1:]:
        sys.stdout.write('%s Version: %s\n' % (args[0], VERSION))
        sys.exit(0)


def assert_root():
    """Exit with status 1 and a message on stderr unless the effective uid is 0."""
    if os.geteuid() != 0:
        sys.stderr.write('only root can run amazon-efs-mount-watchdog\n')
        sys.exit(1)


def read_config(config_file=CONFIG_FILE):
    """Parse config_file and return the populated parser object."""
    try:
        # Python 2: ConfigParser is the module and provides SafeConfigParser
        p = ConfigParser.SafeConfigParser()
    except AttributeError:
        # Python 3: ConfigParser is the class imported from configparser
        p = ConfigParser()
    p.read(config_file)
    return p


512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
def check_certificate(config, state, state_file_dir, state_file, base_path=STATE_FILE_DIR):
    """
    Recreate the mount's self-signed certificate when it is missing or due for refresh.

    A refresh is due once more than REFRESH_SELF_SIGNED_CERT_INTERVAL_SEC seconds have passed
    since state['certificateCreationTime']. After a successful recreation the new creation time
    is written to the state file and SIGHUP is sent to the tunnel's process group. Nothing is
    done when the state's access point id does not match AP_ID_RE.
    """
    certificate_creation_time = datetime.strptime(state['certificateCreationTime'], CERT_DATETIME_FORMAT)
    certificate_exists = os.path.isfile(state['certificate'])
    # creation instead of NOT_BEFORE datetime is used for refresh of cert because NOT_BEFORE derives from creation datetime
    should_refresh_cert = (get_utc_now() - certificate_creation_time).total_seconds() > REFRESH_SELF_SIGNED_CERT_INTERVAL_SEC

    if certificate_exists and not should_refresh_cert:
        return

    ap_state = state.get('accessPoint')
    if ap_state and not AP_ID_RE.match(ap_state):
        logging.error('Access Point ID "%s" has been changed in the state file to a malformed format' % ap_state)
        return

    if not certificate_exists:
        logging.debug('Certificate (at %s) is missing. Recreating self-signed certificate' % state['certificate'])
    else:
        logging.debug('Refreshing self-signed certificate (at %s)' % state['certificate'])

    credentials_source = state.get('awsCredentialsMethod')
    updated_certificate_creation_time = recreate_certificate(config, state['mountStateDir'], state['commonName'], state['fsId'],
                                                             credentials_source, ap_state, state['region'], base_path=base_path)
    if updated_certificate_creation_time:
        state['certificateCreationTime'] = updated_certificate_creation_time
        rewrite_state_file(state, state_file_dir, state_file)

        # send SIGHUP to force a reload of the configuration file to trigger the stunnel process to notice the new certificate
        if is_pid_running(state['pid']):
            process_group = os.getpgid(state['pid'])
            logging.info('SIGHUP signal to stunnel. PID: %d, group ID: %s', state['pid'], process_group)
            os.killpg(process_group, SIGHUP)

        if not is_pid_running(state['pid']):
            logging.warning('TLS tunnel is not running for %s', state_file)


def create_required_directory(config, directory):
    """
    Create directory (and any missing parents) with the configured mode.

    The mode is the octal 'state_file_dir_mode' option of CONFIG_SECTION, defaulting to 0o750 when
    the option is absent or not valid octal. An already existing directory is not an error; any
    other OSError propagates.
    """
    mode = 0o750
    try:
        mode_str = config.get(CONFIG_SECTION, 'state_file_dir_mode')
    except NoOptionError:
        mode_str = None

    if mode_str is not None:
        try:
            mode = int(mode_str, 8)
        except ValueError:
            logging.warning('Bad state_file_dir_mode "%s" in config file "%s"', mode_str, CONFIG_FILE)

    try:
        os.makedirs(directory, mode)
    except OSError as e:
        already_a_directory = e.errno == errno.EEXIST and os.path.isdir(directory)
        if not already_a_directory:
            raise
    else:
        logging.debug('Expected %s not found, recreating asset', directory)


567
568
def recreate_certificate(config, mount_name, common_name, fs_id, credentials_source, ap_id, region,
                         base_path=STATE_FILE_DIR):
    """
    Regenerate the self-signed client certificate for a mount.

    Makes sure the CA directories, supporting files and private key exist, writes a fresh OpenSSL
    configuration, creates a signing request and self-signs it with 'openssl ca'. A public key is
    only generated when credentials_source is set.

    Returns the creation time formatted with CERT_DATETIME_FORMAT, or None when the OpenSSL
    configuration could not be built.
    """
    current_time = get_utc_now()
    tls_paths = tls_paths_dictionary(mount_name, base_path)

    certificate_config = os.path.join(tls_paths['mount_dir'], 'config.conf')
    certificate_signing_request = os.path.join(tls_paths['mount_dir'], 'request.csr')
    certificate = os.path.join(tls_paths['mount_dir'], 'certificate.pem')

    ca_dirs_check(config, tls_paths['database_dir'], tls_paths['certs_dir'])
    ca_supporting_files_check(tls_paths['index'], tls_paths['index_attr'], tls_paths['serial'], tls_paths['rand'])

    private_key = check_and_create_private_key(base_path)

    if credentials_source:
        public_key = os.path.join(tls_paths['mount_dir'], 'publicKey.pem')
        create_public_key(private_key, public_key)

    config_body = create_ca_conf(certificate_config, common_name, tls_paths['mount_dir'], private_key, current_time, region,
                                 fs_id, credentials_source, ap_id=ap_id)

    if not config_body:
        logging.error('Cannot recreate self-signed certificate')
        return None

    create_certificate_signing_request(certificate_config, private_key, certificate_signing_request)

    not_before = get_certificate_timestamp(current_time, minutes=-NOT_BEFORE_MINS)
    not_after = get_certificate_timestamp(current_time, hours=NOT_AFTER_HOURS)

    cmd = 'openssl ca -startdate %s -enddate %s -selfsign -batch -notext -config %s -in %s -out %s' % \
          (not_before, not_after, certificate_config, certificate_signing_request, certificate)
    # NOTE(review): subprocess_call only logs a failure, so the timestamp below is returned even
    # if this openssl command failed - confirm that is acceptable.
    subprocess_call(cmd, 'Failed to create self-signed client-side certificate')
    return current_time.strftime(CERT_DATETIME_FORMAT)


def get_private_key_path():
    """Return the private key location (PRIVATE_KEY_FILE); kept as a function so unit tests can mock it."""
    return PRIVATE_KEY_FILE


def check_and_create_private_key(base_path=STATE_FILE_DIR):
    """
    Return the path of the shared RSA private key, generating the key first when the file is missing.

    Generation runs while holding an exclusive lock file named 'efs-utils-lock' under base_path.
    """
    # Creating RSA private keys is slow, so we will create one private key and allow mounts to share it.
    # This means, however, that we have to include a locking mechanism to ensure that the private key is
    # atomically created, as mounts occurring in parallel may try to create the key simultaneously.
    # The key should have been created during mounting, but the watchdog will recreate the private key if
    # it is missing.
    key = get_private_key_path()

    @contextmanager
    def open_lock_file():
        lock_file = os.path.join(base_path, 'efs-utils-lock')
        # O_CREAT | O_EXCL makes os.open fail with EEXIST while another holder's lock file exists
        f = os.open(lock_file, os.O_CREAT | os.O_DSYNC | os.O_EXCL | os.O_RDWR)
        try:
            lock_file_contents = 'PID: %s' % os.getpid()
            os.write(f, lock_file_contents.encode('utf-8'))
            yield f
        finally:
            # release the lock: close the descriptor and delete the file
            os.close(f)
            os.remove(lock_file)

    def do_with_lock(function):
        # retry every 50 ms until the lock file can be created; any other OSError propagates
        while True:
            try:
                with open_lock_file():
                    return function()
            except OSError as e:
                if e.errno == errno.EEXIST:
                    logging.info('Failed to take out private key creation lock, sleeping 50 ms')
                    time.sleep(0.05)
                else:
                    raise

    def generate_key():
        # re-checked under the lock: another process may have created the key in the meantime
        if os.path.isfile(key):
            return

        cmd = 'openssl genpkey -algorithm RSA -out %s -pkeyopt rsa_keygen_bits:3072' % key
        subprocess_call(cmd, 'Failed to create private key')
        read_only_mode = 0o400
        os.chmod(key, read_only_mode)

    do_with_lock(generate_key)
    return key


def create_certificate_signing_request(config_path, key_path, csr_path):
    """Run 'openssl req' with the config at config_path and the key at key_path, writing the request to csr_path."""
    openssl_args = (config_path, key_path, csr_path)
    subprocess_call('openssl req -new -config %s -key %s -out %s' % openssl_args,
                    'Failed to create certificate signing request (csr)')


658
659
def create_ca_conf(config_path, common_name, directory, private_key, date, region, fs_id, credentials_source,
                   ap_id=None):
    """
    Populate ca/req configuration file with fresh configurations at every mount since SigV4 signature can change

    When credentials_source is set, AWS credentials are looked up and a signed
    "[ efs_client_auth ]" section is included; otherwise that section is left empty.

    Returns the text written to config_path, or None (nothing written) when the credentials or
    the signature section could not be produced.
    """
    public_key_path = os.path.join(directory, 'publicKey.pem')
    security_credentials = get_aws_security_credentials(credentials_source) if credentials_source else ''

    if credentials_source and security_credentials is None:
        logging.error('Failed to retrieve AWS security credentials using lookup method: %s', credentials_source)
        return None

    ca_extension_body = ca_extension_builder(ap_id, security_credentials, fs_id)
    efs_client_auth_body = efs_client_auth_builder(public_key_path, security_credentials['AccessKeyId'],
                                                   security_credentials['SecretAccessKey'], date, region, fs_id,
                                                   security_credentials['Token']) if credentials_source else ''

    if credentials_source and not efs_client_auth_body:
        logging.error('Failed to create AWS SigV4 signature section for OpenSSL config. Public Key path: %s', public_key_path)
        return None

    full_config_body = CA_CONFIG_BODY % (directory, private_key, common_name, ca_extension_body, efs_client_auth_body)

    with open(config_path, 'w') as f:
        f.write(full_config_body)

    return full_config_body


685
def ca_extension_builder(ap_id, security_credentials, fs_id):
    """
    Build the "[ v3_ca ]" extension section of the OpenSSL configuration.

    The access point id line is added only when ap_id is truthy, and the reference to the
    efs_client_auth sequence only when security_credentials is truthy. The file system id line
    is always added.
    """
    ca_extension_str = '[ v3_ca ]\nsubjectKeyIdentifier = hash'
    if ap_id:
        ca_extension_str += '\n1.3.6.1.4.1.4843.7.1 = ASN1:UTF8String:' + ap_id
    if security_credentials:
        ca_extension_str += '\n1.3.6.1.4.1.4843.7.2 = ASN1:SEQUENCE:efs_client_auth'

    ca_extension_str += '\n1.3.6.1.4.1.4843.7.3 = ASN1:UTF8String:' + fs_id

    return ca_extension_str


def efs_client_auth_builder(public_key_path, access_key_id, secret_access_key, date, region, fs_id, session_token=None):
    """
    Assemble the '[ efs_client_auth ]' OpenSSL config section holding the access key id, the SigV4 signature and the
    signing timestamp, plus the session token when one is supplied.

    Returns None when no SHA-1 hash could be obtained for the public key at public_key_path.
    """
    key_digest = get_public_key_sha1(public_key_path)
    if not key_digest:
        return None

    # SigV4 pipeline: canonical request -> string to sign -> signature
    request = create_canonical_request(key_digest, date, access_key_id, region, fs_id, session_token)
    to_sign = create_string_to_sign(request, date, region)
    sigv4_signature = calculate_signature(to_sign, date, secret_access_key, region)

    section_parts = [
        '[ efs_client_auth ]',
        '\naccessKeyId = UTF8String:' + access_key_id,
        '\nsignature = OCTETSTRING:' + sigv4_signature,
        '\nsigv4DateTime = UTCTIME:' + date.strftime(CERT_DATETIME_FORMAT),
    ]
    if session_token:
        section_parts.append('\nsessionToken = EXPLICIT:0,UTF8String:' + session_token)

    return ''.join(section_parts)


def create_public_key(private_key, public_key):
    """Run 'openssl rsa' on the private_key path to write the matching PEM public key to the public_key path."""
    openssl_command = 'openssl rsa -in %s -outform PEM -pubout -out %s' % (private_key, public_key)
    subprocess_call(openssl_command, 'Failed to create public key')


def subprocess_call(cmd, error_message):
    """
    Helper method to run shell openssl command and to handle response error messages

    The command string is split on whitespace and executed without a shell. Returns the (stdout, stderr) tuple when
    the command exits with 0; otherwise logs the failure at debug level and returns None.
    """
    proc = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
    stdout_data, stderr_data = proc.communicate()
    exit_code = proc.poll()

    if exit_code == 0:
        return stdout_data, stderr_data

    logging.debug('%s. Command %s failed, rc=%s, stdout="%s", stderr="%s"',
                  error_message, cmd, exit_code, stdout_data, stderr_data)
    return None


def ca_dirs_check(config, database_dir, certs_dir):
    """Make sure the mount's database and certs directories are present, delegating to create_required_directory
    for whichever of the two does not exist yet (database directory first, then certs directory)."""
    for required_dir in (database_dir, certs_dir):
        if not os.path.exists(required_dir):
            create_required_directory(config, required_dir)


def ca_supporting_files_check(index_path, index_attr_path, serial_path, rand_path):
    """Recreate any missing openssl ca/req supporting file and log a warning for each file that had to be recreated.

    Files that already exist are left untouched. The index and rand files are created empty; the index attribute
    and serial files are created with their initial contents.
    """
    # (path, open mode, initial contents or None for an empty file), in creation order
    expected_files = (
        (index_path, 'w', None),
        (index_attr_path, 'w+', 'unique_subject = no'),
        (serial_path, 'w+', '00'),
        (rand_path, 'w', None),
    )

    for path, mode, initial_contents in expected_files:
        if os.path.isfile(path):
            continue

        with open(path, mode) as new_file:
            if initial_contents is not None:
                new_file.write(initial_contents)
        logging.warning('Expected %s not found, recreating file', path)


def tls_paths_dictionary(mount_name, base_path=STATE_FILE_DIR):
    """Map each TLS-related directory/file key to its path under <base_path>/<mount_name>."""
    mount_dir = os.path.join(base_path, mount_name)

    # key -> location relative to the mount directory
    relative_locations = (
        ('database_dir', 'database'),
        ('certs_dir', 'certs'),
        ('index', 'database/index.txt'),
        ('index_attr', 'database/index.txt.attr'),
        ('serial', 'database/serial'),
        ('rand', 'database/.rand'),
    )

    tls_dict = {'mount_dir': mount_dir}
    tls_dict.update((key, os.path.join(mount_dir, relative)) for key, relative in relative_locations)

    return tls_dict


def get_public_key_sha1(public_key):
    """
    Compute the SHA-1 hex digest of the key material stored in a PEM public key file.

    The file is base64-decoded (minus its first and last line), 'openssl asn1parse' is used to find the offset of
    the key BIT STRING, and the BIT STRING's content (without tag, length and unused-bits octets) is hashed.

    :param public_key: path to the PEM encoded public key file
    :return: the hex encoded SHA-1 digest, or None when openssl fails to parse the file or its output contains no
             BIT STRING entry
    """
    # truncating public key to remove the header and footer '-----(BEGIN|END) PUBLIC KEY-----'
    with open(public_key, 'r') as f:
        lines = f.readlines()
        lines = lines[1:-1]

    key = ''.join(lines)
    key = bytearray(base64.b64decode(key))

    # Parse the public key to pull out the actual key material by looking for the key BIT STRING
    # Example:
    #     0:d=0  hl=4 l= 418 cons: SEQUENCE
    #     4:d=1  hl=2 l=  13 cons: SEQUENCE
    #     6:d=2  hl=2 l=   9 prim: OBJECT            :rsaEncryption
    #    17:d=2  hl=2 l=   0 prim: NULL
    #    19:d=1  hl=4 l= 399 prim: BIT STRING
    cmd = 'openssl asn1parse -inform PEM -in %s' % public_key
    result = subprocess_call(cmd, 'Unable to ASN1 parse public key file, %s, correctly' % public_key)

    # subprocess_call returns None (after logging at debug level) when the command exits non-zero; unpacking that
    # directly would raise a TypeError instead of reporting the failure to the caller
    if not result:
        logging.error('Unable to ASN1 parse public key file, %s', public_key)
        return None

    output = result[0]

    # Keep the last line mentioning a BIT STRING
    key_line = ''
    for line in output.splitlines():
        decoded_line = line.decode('utf-8')
        if 'BIT STRING' in decoded_line:
            key_line = decoded_line

    if not key_line:
        logging.error('Public key file, %s, is incorrectly formatted', public_key)
        return None

    key_line = key_line.replace(' ', '')

    # DER encoding TLV (Tag, Length, Value)
    # - the first octet (byte) is the tag (type)
    # - the next octets are the length - "definite form"
    #   - short form: a single octet with the high order bit (8) set to 0 holds the length itself
    #   - long form: the first octet has the high order bit (8) set to 1
    #     - the remaining 7 bits are used to encode the number of octets that follow
    #     - the following octets encode, as big-endian, the length (which may be 0) as a number of octets
    # - the remaining octets are the "value" aka content
    #
    # For a BIT STRING, the first octet of the value is used to signify the number of unused bits that exist in the last
    # content byte. Note that this is explicitly excluded from the SubjectKeyIdentifier hash, per
    # https://tools.ietf.org/html/rfc5280#section-4.2.1.2
    #
    # Example:
    #   0382018f00...<subjectPublicKey>
    #   - 03 - BIT STRING tag
    #   - 82 - 2 length octets to follow (ignore high order bit)
    #   - 018f - length of 399
    #   - 00 - no unused bits in the last content byte
    offset = int(key_line.split(':')[0])
    key = key[offset:]

    first_length_octet = key[1]
    if first_length_octet & 0b10000000:
        num_length_octets = first_length_octet & 0b01111111
    else:
        # Short form: the octet is the length itself, so no additional length octets follow
        num_length_octets = 0

    # Exclude the tag (1), length (1 + num_length_octets), and number of unused bits (1)
    offset = 1 + 1 + num_length_octets + 1
    key = key[offset:]

    sha1 = hashlib.sha1()
    sha1.update(key)

    return sha1.hexdigest()


def create_canonical_request(public_key_hash, date, access_key, region, fs_id, session_token=None):
    """
    Create a Canonical Request - https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html

    The request is the newline-separated sequence: HTTP method, canonical URI, canonical query string, canonical
    headers (formatted with fs_id), signed headers and the hex SHA-256 digest of the request payload.
    """
    timestamp = date.strftime(SIGV4_DATETIME_FORMAT)
    scoped_credential = quote_plus(access_key + '/' + get_credential_scope(date, region))
    query_string = create_canonical_query_string(public_key_hash, scoped_credential, timestamp, session_token)
    payload_digest = hashlib.sha256(REQUEST_PAYLOAD.encode()).hexdigest()

    sections = [
        HTTP_REQUEST_METHOD,
        CANONICAL_URI,
        query_string,
        CANONICAL_HEADERS % fs_id,
        SIGNED_HEADERS,
        payload_digest,
    ]

    return '\n'.join(sections)


def create_canonical_query_string(public_key_hash, credential, formatted_datetime, session_token=None):
    """Build the 'key=value' query string, joined with '&' and ordered by parameter name.

    The credential is inserted as given; the public key hash, datetime and optional session token are passed
    through quote_plus here.
    """
    query_pairs = [
        ('Action', 'Connect'),
        # Public key hash is included in canonical request to tie the signature to a specific key pair to avoid replay attacks
        ('PublicKeyHash', quote_plus(public_key_hash)),
        ('X-Amz-Algorithm', ALGORITHM),
        ('X-Amz-Credential', credential),
        ('X-Amz-Date', quote_plus(formatted_datetime)),
        ('X-Amz-Expires', 86400),
        ('X-Amz-SignedHeaders', SIGNED_HEADERS),
    ]

    if session_token:
        query_pairs.append(('X-Amz-Security-Token', quote_plus(session_token)))

    # Cannot use urllib.urlencode because it replaces the %s's
    ordered_pairs = sorted(query_pairs, key=lambda pair: pair[0])
    return '&'.join(['%s=%s' % pair for pair in ordered_pairs])


def create_string_to_sign(canonical_request, date, region):
    """
    Create a String to Sign - https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html

    The result is the newline-separated sequence: algorithm, formatted datetime, credential scope and the hex
    SHA-256 digest of the canonical request.
    """
    request_digest = hashlib.sha256(canonical_request.encode()).hexdigest()

    components = [
        ALGORITHM,
        date.strftime(SIGV4_DATETIME_FORMAT),
        get_credential_scope(date, region),
        request_digest,
    ]

    return '\n'.join(components)


def calculate_signature(string_to_sign, date, secret_access_key, region):
    """
    Calculate the Signature - https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html

    The signing key is derived by chaining HMAC-SHA256 over the date, region, service and 'aws4_request', starting
    from 'AWS4' + secret_access_key; the hex HMAC-SHA256 of string_to_sign under that key is returned.
    """
    def _hmac_sha256(key, message):
        return hmac.new(key, message.encode('utf-8'), hashlib.sha256)

    derived_key = ('AWS4' + secret_access_key).encode('utf-8')
    for scope_component in (date.strftime(DATE_ONLY_FORMAT), region, SERVICE, 'aws4_request'):
        derived_key = _hmac_sha256(derived_key, scope_component).digest()

    return _hmac_sha256(derived_key, string_to_sign).hexdigest()


def get_credential_scope(date, region):
    """Return the '/'-separated credential scope: formatted date, region, service and the AWS4_REQUEST terminator."""
    scope_parts = (date.strftime(DATE_ONLY_FORMAT), region, SERVICE, AWS4_REQUEST)
    return '/'.join(scope_parts)


def get_certificate_timestamp(current_time, **kwargs):
    """Shift current_time by timedelta(**kwargs) and return it formatted with CERT_DATETIME_FORMAT."""
    shift = timedelta(**kwargs)
    return (current_time + shift).strftime(CERT_DATETIME_FORMAT)


def get_utc_now():
    """
    Return the current UTC time as a naive datetime.

    Kept as a separate function so unit tests can patch it.
    """
    utc_now = datetime.utcnow()
    return utc_now


Max Beckett's avatar
Max Beckett committed
925
926
927
928
929
930
931
932
933
934
935
936
937
938
def main():
    """
    Entry point of the watchdog.

    Parses the arguments, asserts root, reads the config and bootstraps logging. When the 'enabled' option of the
    watchdog config section is true, loops forever: checks the EFS mounts and the tracked child processes, then
    sleeps for 'poll_interval_sec' seconds. Otherwise logs that the watchdog is not enabled and returns.
    """
    parse_arguments()
    assert_root()

    config = read_config()
    bootstrap_logging(config)

    # The same list is passed to both checks on every iteration
    child_procs = []

    if config.getboolean(CONFIG_SECTION, 'enabled'):
        poll_interval_sec = config.getint(CONFIG_SECTION, 'poll_interval_sec')
        unmount_grace_period_sec = config.getint(CONFIG_SECTION, 'unmount_grace_period_sec')

        while True:
            check_efs_mounts(config, child_procs, unmount_grace_period_sec)
            check_child_procs(child_procs)

            time.sleep(poll_interval_sec)
    else:
        logging.info('amazon-efs-mount-watchdog is not enabled')


# Run the watchdog only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()