__init__.py 43.2 KB
Newer Older
Max Beckett's avatar
Max Beckett committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
#
# Copyright 2017-2018 Amazon.com, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the MIT License. See the LICENSE accompanying this file
# for the specific language governing permissions and limitations under
# the License.
#

10
import base64
Max Beckett's avatar
Max Beckett committed
11
import errno
12
13
import hashlib
import hmac
Max Beckett's avatar
Max Beckett committed
14
15
16
17
import json
import logging
import logging.handlers
import os
18
19
20
import pwd
import re
import shutil
Max Beckett's avatar
Max Beckett committed
21
22
23
24
25
import subprocess
import sys
import time

from collections import namedtuple
26
27
from contextlib import contextmanager
from datetime import datetime, timedelta
Max Beckett's avatar
Max Beckett committed
28
from logging.handlers import RotatingFileHandler
29
from signal import SIGTERM, SIGHUP
Max Beckett's avatar
Max Beckett committed
30
31
32

try:
    import ConfigParser
33
    from ConfigParser import NoOptionError, NoSectionError
Max Beckett's avatar
Max Beckett committed
34
except ImportError:
35
    from configparser import ConfigParser, NoOptionError, NoSectionError
Max Beckett's avatar
Max Beckett committed
36

37
38
39
40
41
42
try:
    from urllib.parse import quote_plus
except ImportError:
    from urllib import quote_plus

try:
43
    from urllib2 import build_opener, urlopen, URLError, HTTPError, HTTPHandler, Request
44
    from urllib import urlencode
45
except ImportError:
46
    from urllib.error import HTTPError, URLError
47
48
    from urllib.request import urlopen, Request
    from urllib.parse import urlencode
49

50
51

VERSION = '1.28.2'
52
SERVICE = 'elasticfilesystem'
Max Beckett's avatar
Max Beckett committed
53
54
55

CONFIG_FILE = '/etc/amazon/efs/efs-utils.conf'
CONFIG_SECTION = 'mount-watchdog'
56
57
CLIENT_INFO_SECTION = 'client-info'
CLIENT_SOURCE_STR_LEN_LIMIT = 100
Max Beckett's avatar
Max Beckett committed
58
59
60
61
62
63

LOG_DIR = '/var/log/amazon/efs'
LOG_FILE = 'mount-watchdog.log'

STATE_FILE_DIR = '/var/run/efs'

64
PRIVATE_KEY_FILE = '/etc/amazon/efs/privateKey.pem'
65
DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN = 60
66
67
68
69
70
71
# Validity window of the self-signed client certificate, relative to its creation time: recreate_certificate
# back-dates notBefore by NOT_BEFORE_MINS minutes (presumably to tolerate clock skew) and sets notAfter
# NOT_AFTER_HOURS hours ahead.
NOT_BEFORE_MINS = 15
NOT_AFTER_HOURS = 3
# strftime/strptime formats. DATE_ONLY_FORMAT and SIGV4_DATETIME_FORMAT are consumed by code not visible in
# this chunk (presumably SigV4 signing). CERT_DATETIME_FORMAT is the format of the state file's
# certificateCreationTime value.
DATE_ONLY_FORMAT = '%Y%m%d'
SIGV4_DATETIME_FORMAT = '%Y%m%dT%H%M%SZ'
CERT_DATETIME_FORMAT = '%y%m%d%H%M%SZ'

72
73
74
75
76
# Shared AWS CLI credentials/config files of the user running this process, keyed by the name used in the
# "credentials:<profile>" / "config:<profile>" credential sources. The path is built from '~<username>'
# rather than plain '~', so expanduser resolves it from the password database instead of $HOME. It is
# evaluated once, at import time.
AWS_CREDENTIALS_FILES = {
    'credentials': os.path.expanduser(os.path.join('~' + pwd.getpwuid(os.getuid()).pw_name, '.aws', 'credentials')),
    'config': os.path.expanduser(os.path.join('~' + pwd.getpwuid(os.getuid()).pw_name, '.aws', 'config')),
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
CA_CONFIG_BODY = """dir = %s
RANDFILE = $dir/database/.rand

[ ca ]
default_ca = local_ca

[ local_ca ]
database = $dir/database/index.txt
serial = $dir/database/serial
private_key = %s
cert = $dir/certificate.pem
new_certs_dir = $dir/certs
default_md = sha256
preserve = no
policy = efsPolicy
x509_extensions = v3_ca

[ efsPolicy ]
CN = supplied

[ req ]
prompt = no
distinguished_name = req_distinguished_name

[ req_distinguished_name ]
CN = %s

%s

106
107
%s

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
%s
"""

# SigV4 Auth: fixed pieces of an AWS Signature Version 4 request; the code that assembles them is not visible
# in this chunk.
ALGORITHM = 'AWS4-HMAC-SHA256'
AWS4_REQUEST = 'aws4_request'

HTTP_REQUEST_METHOD = 'GET'
CANONICAL_URI = '/'
# Header values are '%s' placeholders, filled in later by %-formatting (presumably with the endpoint host).
CANONICAL_HEADERS_DICT = {
    'host': '%s'
}
CANONICAL_HEADERS = '\n'.join(['%s:%s' % (k, v) for k, v in sorted(CANONICAL_HEADERS_DICT.items())])
# NOTE(review): unlike CANONICAL_HEADERS these names are not sorted; equivalent while there is a single header.
SIGNED_HEADERS = ';'.join(CANONICAL_HEADERS_DICT.keys())
REQUEST_PAYLOAD = ''

# Well-formed EFS access point ID: "fsap-" followed by 17 lowercase hex digits. check_certificate uses it to
# reject a tampered accessPoint value in a state file.
AP_ID_RE = re.compile('^fsap-[0-9a-f]{17}$')

126
ECS_TASK_METADATA_API = 'http://169.254.170.2'
127
STS_ENDPOINT_URL = 'https://sts.amazonaws.com/'
128
INSTANCE_IAM_URL = 'http://169.254.169.254/latest/meta-data/iam/security-credentials/'
129
INSTANCE_METADATA_TOKEN_URL = 'http://169.254.169.254/latest/api/token'
130
SECURITY_CREDS_ECS_URI_HELP_URL = 'https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html'
131
SECURITY_CREDS_WEBIDENTITY_HELP_URL = 'https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html'
132
133
SECURITY_CREDS_IAM_ROLE_HELP_URL = 'https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html'

Max Beckett's avatar
Max Beckett committed
134
135
136
137
138
139
140
141
142
143
144
145
# One line of a mounts file such as /proc/mounts (see get_current_local_nfs_mounts), split on whitespace into
# its six fields: device/server, mount point, filesystem type, comma-separated options, dump frequency and
# fsck pass number.
Mount = namedtuple('Mount', ['server', 'mountpoint', 'type', 'options', 'freq', 'passno'])


def fatal_error(user_message, log_message=None):
    """
    Report an unrecoverable condition and terminate the process with exit status 1.

    user_message is written to stderr. log_message is sent to the error log; when it is omitted,
    user_message is logged instead.
    """
    message_for_log = user_message if log_message is None else log_message

    sys.stderr.write('%s\n' % user_message)
    logging.error(message_for_log)
    sys.exit(1)


146
def get_aws_security_credentials(credentials_source):
    """
    Lookup AWS security credentials (access key ID and secret access key). Adapted credentials provider chain from:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html and
    https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html

    credentials_source is the "<method>:<value>" string recorded in the mount state file. Recognised methods:
    "credentials:<profile>", "config:<profile>", "ecs:<relative uri>", "webidentity:<role arn>,<token file>"
    and "metadata:".

    Returns the credentials dict produced by the matching lookup. Returns None if that lookup fails, or (after
    logging an error) if the source string is malformed.
    """
    # partition() tolerates a string without ':' (it simply yields an unknown method) instead of raising
    # ValueError as tuple-unpacking split(':', 1) did.
    method, _, value = credentials_source.partition(':')

    if method == 'credentials':
        return get_aws_security_credentials_from_file('credentials', value)
    elif method == 'config':
        return get_aws_security_credentials_from_file('config', value)
    elif method == 'ecs':
        return get_aws_security_credentials_from_ecs(value)
    elif method == 'webidentity':
        webidentity_args = value.split(',')
        # Expect exactly "<role arn>,<token file>"; anything else is reported as malformed below rather than
        # raising TypeError from the call.
        if len(webidentity_args) == 2:
            return get_aws_security_credentials_from_webidentity(*webidentity_args)
    elif method == 'metadata':
        return get_aws_security_credentials_from_instance_metadata()

    logging.error('Improper credentials source string "%s" found from mount state file', credentials_source)
    return None

168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def get_aws_ec2_metadata_token(timeout=1):
    """
    Request an IMDSv2 session token from the instance metadata service.

    The token is obtained with an HTTP PUT to INSTANCE_METADATA_TOKEN_URL asking for a TTL of 21600 seconds.
    The raw response body is returned.

    :param timeout: seconds to wait for the metadata endpoint. Without it, an unresponsive endpoint would block
        the watchdog indefinitely. Defaults to 1 second, the same timeout url_request_helper uses.
    """
    ttl_header = 'X-aws-ec2-metadata-token-ttl-seconds'
    ttl_seconds = 21600
    try:
        # Python 2 path: build_opener/HTTPHandler are only imported from urllib2, and urllib2's Request has no
        # `method` argument, so PUT is forced by overriding get_method.
        opener = build_opener(HTTPHandler)
        request = Request(INSTANCE_METADATA_TOKEN_URL)
        request.add_header(ttl_header, ttl_seconds)
        request.get_method = lambda: 'PUT'
        res = opener.open(request, timeout=timeout)
        return res.read()
    except NameError:
        # Python 3 path: build_opener/HTTPHandler are not imported there, so the lookup above raises NameError.
        req = Request(INSTANCE_METADATA_TOKEN_URL, headers={ttl_header: ttl_seconds}, method='PUT')
        res = urlopen(req, timeout=timeout)
        return res.read()


184
185
186
187
188
def get_aws_security_credentials_from_file(file_name, awsprofile):
    """
    Look up credentials for named profile `awsprofile` in one of the AWS CLI files.

    file_name selects the path from AWS_CREDENTIALS_FILES ('credentials' or 'config'). Returns the dict built by
    credentials_file_helper if the file exists and an access key ID was found. Otherwise logs an error and
    returns None.
    """
    file_path = AWS_CREDENTIALS_FILES.get(file_name)

    found_credentials = None
    if file_path and os.path.exists(file_path):
        candidate = credentials_file_helper(file_path, awsprofile)
        if candidate['AccessKeyId']:
            found_credentials = candidate

    if found_credentials is None:
        logging.error('AWS security credentials not found in %s under named profile [%s]', file_path, awsprofile)
    return found_credentials
194
195


196
def get_aws_security_credentials_from_ecs(uri):
    """
    Fetch task-role credentials from the ECS task metadata endpoint.

    `uri` is the relative credentials path (the value found in AWS_CONTAINER_CREDENTIALS_RELATIVE_URI); it is
    appended to ECS_TASK_METADATA_API. Returns the response only if it contains AccessKeyId, SecretAccessKey and
    Token; otherwise None.
    """
    ecs_uri = ECS_TASK_METADATA_API + uri
    unsuccessful_msg = 'Unsuccessful retrieval of AWS security credentials at %s.' % ecs_uri
    unreachable_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
                      (ecs_uri, SECURITY_CREDS_ECS_URI_HELP_URL)

    response = url_request_helper(ecs_uri, unsuccessful_msg, unreachable_msg)

    required_keys = ('AccessKeyId', 'SecretAccessKey', 'Token')
    if not response or any(key not in response for key in required_keys):
        return None
    return response
209
210


211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def get_aws_security_credentials_from_webidentity(role_arn, token_file):
    """
    Exchange a web identity token for temporary credentials via STS AssumeRoleWithWebIdentity.

    The token is read from token_file and sent to STS_ENDPOINT_URL to assume role_arn. Returns a dict with
    AccessKeyId, SecretAccessKey and Token (taken from the STS SessionToken). Returns None if the token file
    cannot be read, STS cannot be reached, or the response lacks any of the credential fields.
    """
    try:
        with open(token_file, 'r') as token_fh:
            token = token_fh.read()
    except Exception as e:
        logging.error('Error reading token file %s: %s', token_file, e)
        return None

    query_string = urlencode({
        'Version': '2011-06-15',
        'Action': 'AssumeRoleWithWebIdentity',
        'RoleArn': role_arn,
        'RoleSessionName': 'efs-mount-helper',
        'WebIdentityToken': token
    })
    webidentity_url = STS_ENDPOINT_URL + '?' + query_string

    unsuccessful_resp = 'Unsuccessful retrieval of AWS security credentials at %s.' % STS_ENDPOINT_URL
    url_error_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
                    (STS_ENDPOINT_URL, SECURITY_CREDS_WEBIDENTITY_HELP_URL)
    sts_response = url_request_helper(webidentity_url, unsuccessful_resp, url_error_msg,
                                      headers={'Accept': 'application/json'})
    if not sts_response:
        return None

    assume_role_response = sts_response.get('AssumeRoleWithWebIdentityResponse', {})
    assume_role_result = assume_role_response.get('AssumeRoleWithWebIdentityResult', {})
    sts_creds = assume_role_result.get('Credentials', {})

    if not all(field in sts_creds for field in ['AccessKeyId', 'SecretAccessKey', 'SessionToken']):
        return None

    return {
        'AccessKeyId': sts_creds['AccessKeyId'],
        'SecretAccessKey': sts_creds['SecretAccessKey'],
        'Token': sts_creds['SessionToken']
    }


247
def get_aws_security_credentials_from_instance_metadata():
    """
    Fetch credentials for the IAM role attached to this EC2 instance from the instance metadata service.

    Two requests are made, each retried with an IMDSv2 token on HTTP 401:
    1. The role name is read from INSTANCE_IAM_URL.
    2. The credentials document is read from INSTANCE_IAM_URL + <role name>.

    The document is returned only if it contains AccessKeyId, SecretAccessKey and Token; otherwise None.
    """
    role_unsuccessful_resp = 'Unsuccessful retrieval of IAM role name at %s.' % INSTANCE_IAM_URL
    role_url_error_msg = 'Unable to reach %s to retrieve IAM role name. See %s for more info.' % \
                         (INSTANCE_IAM_URL, SECURITY_CREDS_IAM_ROLE_HELP_URL)
    iam_role_name = url_request_helper(INSTANCE_IAM_URL, role_unsuccessful_resp, role_url_error_msg,
                                       retry_with_new_header_token=True)
    if not iam_role_name:
        return None

    creds_url = INSTANCE_IAM_URL + iam_role_name
    creds_unsuccessful_resp = 'Unsuccessful retrieval of AWS security credentials at %s.' % creds_url
    creds_url_error_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
                          (creds_url, SECURITY_CREDS_IAM_ROLE_HELP_URL)
    creds = url_request_helper(creds_url, creds_unsuccessful_resp, creds_url_error_msg,
                               retry_with_new_header_token=True)

    required_keys = ('AccessKeyId', 'SecretAccessKey', 'Token')
    if creds and all(key in creds for key in required_keys):
        return creds
    return None
267
268


269
def credentials_file_helper(file_path, awsprofile):
    """
    Read the key pair and session token for section `awsprofile` of the INI-style file at file_path.

    Always returns a dict with keys AccessKeyId, SecretAccessKey and Token:
    - all three options present: all three values are filled in;
    - only aws_session_token missing: the key pair is filled in and Token stays None;
    - either key missing, or the section missing: all three stay None.
    """
    parser = read_config(file_path)
    credentials = dict.fromkeys(('AccessKeyId', 'SecretAccessKey', 'Token'))

    try:
        access_key_id = parser.get(awsprofile, 'aws_access_key_id')
        secret_access_key = parser.get(awsprofile, 'aws_secret_access_key')
        session_token = parser.get(awsprofile, 'aws_session_token')
    except NoOptionError as e:
        missing_option = str(e)
        if 'aws_access_key_id' in missing_option or 'aws_secret_access_key' in missing_option:
            logging.debug('aws_access_key_id or aws_secret_access_key not found in %s under named profile [%s]',
                          file_path, awsprofile)
        if 'aws_session_token' in missing_option:
            logging.debug('aws_session_token not found in %s', file_path)
            # The session token is optional; a long-term key pair is still usable without it.
            credentials['AccessKeyId'] = parser.get(awsprofile, 'aws_access_key_id')
            credentials['SecretAccessKey'] = parser.get(awsprofile, 'aws_secret_access_key')
    except NoSectionError:
        logging.debug('No [%s] section found in config file %s', awsprofile, file_path)
    else:
        credentials['AccessKeyId'] = access_key_id
        credentials['SecretAccessKey'] = secret_access_key
        credentials['Token'] = session_token

    return credentials


295
def url_request_helper(url, unsuccessful_resp, url_error_msg, headers=None, retry_with_new_header_token=False):
    """
    GET `url` with a 1 second timeout and return the body decoded by get_resp_obj.

    :param url: URL to request.
    :param unsuccessful_resp: message get_resp_obj logs when the response code is not 200.
    :param url_error_msg: prefix of the debug message logged when the URL cannot be reached.
    :param headers: optional dict of extra request headers.
    :param retry_with_new_header_token: when True and the first attempt fails with HTTP 401 (as on instances that
        require IMDSv2), fetch a token with get_aws_ec2_metadata_token and retry once with it in the
        X-aws-ec2-metadata-token header.
    :return: whatever get_resp_obj returns, or None if the request (including the token retry) failed.
    """
    req = Request(url)
    for k, v in (headers or {}).items():
        req.add_header(k, v)

    try:
        try:
            request_resp = urlopen(req, timeout=1)
        except HTTPError as e:
            # For instances with IMDSv2 enabled, an Unauthorized 401 error is returned; the metadata must then be
            # requested again with a metadata token in the header.
            if e.code != 401 or not retry_with_new_header_token:
                raise
            token = get_aws_ec2_metadata_token()
            req.add_header('X-aws-ec2-metadata-token', token)
            # Failures of the token fetch or of this retried request are handled by the outer handlers below
            # rather than propagating to the caller.
            request_resp = urlopen(req, timeout=1)
        return get_resp_obj(request_resp, url, unsuccessful_resp)
    except HTTPError as e:
        err_msg = 'Unable to reach the url at %s: status=%d, reason is %s' % (url, e.code, e.reason)
    except URLError as e:
        err_msg = 'Unable to reach the url at %s, reason is %s' % (url, e.reason)

    logging.debug('%s %s', url_error_msg, err_msg)
    return None


def get_resp_obj(request_resp, url, unsuccessful_resp):
    """
    Decode an HTTP response object as returned by urlopen.

    :param request_resp: response object (only getcode(), read() and headers.get_content_charset() are used).
    :param url: the requested URL, used in log messages.
    :param unsuccessful_resp: message logged at debug level when the response code is not 200.
    :return: None for a non-200 response. Otherwise the body parsed as JSON. If the body is not valid JSON, it is
        returned as text instead: decoded as UTF-8 when it was bytes, unchanged when it was already str.
    """
    status = request_resp.getcode()
    if status != 200:
        # unsuccessful_resp embeds a URL that may contain '%', so it is passed as an argument instead of being
        # concatenated into the logging format string.
        logging.debug('%s %s: ResponseCode=%d', unsuccessful_resp, url, status)
        return None

    resp_body = request_resp.read()
    # `is str` (not isinstance) keeps Python 2 semantics, where read() returns str.
    body_is_text = type(resp_body) is str
    try:
        if body_is_text:
            return json.loads(resp_body)
        return json.loads(resp_body.decode(request_resp.headers.get_content_charset() or 'us-ascii'))
    except ValueError as e:
        # Plain-text bodies (e.g. the IAM role name from instance metadata) are not JSON.
        logging.info('ValueError parsing "%s" into json: %s. Returning response body.', str(resp_body), e)
        return resp_body if body_is_text else resp_body.decode('utf-8')

338

Max Beckett's avatar
Max Beckett committed
339
def bootstrap_logging(config, log_dir=LOG_DIR):
    """
    Attach a size-rotated file handler for LOG_FILE in log_dir to the root logger.

    The level, maximum file size and backup count come from the logging_level, logging_max_bytes and
    logging_file_count options of CONFIG_SECTION. An unrecognised level falls back to INFO. That problem is
    reported only after the handler is installed, so the message actually reaches the log file.
    """
    raw_level = config.get(CONFIG_SECTION, 'logging_level')
    known_levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'critical': logging.CRITICAL
    }
    level = known_levels.get(raw_level.lower())
    level_error = not level
    if level_error:
        level = logging.INFO

    max_bytes = config.getint(CONFIG_SECTION, 'logging_max_bytes')
    backup_count = config.getint(CONFIG_SECTION, 'logging_file_count')
    log_path = os.path.join(log_dir, LOG_FILE)

    file_handler = RotatingFileHandler(log_path, maxBytes=max_bytes, backupCount=backup_count)
    file_handler.setFormatter(logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s'))

    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.addHandler(file_handler)

    if level_error:
        logging.error('Malformed logging level "%s", setting logging level to %s', raw_level, level)

Max Beckett's avatar
Max Beckett committed
369

370
371
372
373
374
375
376
377
378
379
380
def parse_options(options):
    """
    Parse a comma-separated mount option string into a dict.

    "key=value" entries map key -> value. The split happens on the first '=' only, so values may themselves
    contain '='. Bare flags map to None. Example: "tls,port=20049" -> {'tls': None, 'port': '20049'}.
    """
    opts = {}
    for option in options.split(','):
        key, separator, value = option.partition('=')
        opts[key] = value if separator else None
    return opts


381
382
def get_file_safe_mountpoint(mount):
    """
    Build the "<mountpoint>.<port>" key that identifies `mount` in watchdog state file names.

    Path separators in the absolute mountpoint become '.', and a leading '.' is dropped. For example, a mount
    at /home/user/mnt with option port=12345 yields "home.user.mnt.12345". Returns None when the mount has no
    port option (some other localhost NFS mount not running over stunnel).
    """
    opts = parse_options(mount.options)
    if 'port' not in opts:
        return None

    dotted_mountpoint = os.path.abspath(mount.mountpoint).replace(os.sep, '.')
    if dotted_mountpoint.startswith('.'):
        dotted_mountpoint = dotted_mountpoint[1:]

    return dotted_mountpoint + '.' + opts['port']
Max Beckett's avatar
Max Beckett committed
391
392
393
394


def get_current_local_nfs_mounts(mount_file='/proc/mounts'):
    """
    Return the current NFS mounts served from 127.0.0.1, as a dict of Mount tuples.

    The dict is keyed by the "<mountpoint>.<port>" form used in watchdog state file names. Mounts without a
    port option are left out, since no such key can be built for them.
    """
    with open(mount_file) as mounts_fh:
        all_mounts = [Mount._make(line.strip().split()) for line in mounts_fh]

    mount_dict = {}
    for candidate in all_mounts:
        if not (candidate.server.startswith('127.0.0.1') and 'nfs' in candidate.type):
            continue
        safe_key = get_file_safe_mountpoint(candidate)
        if safe_key:
            mount_dict[safe_key] = candidate
    return mount_dict


def get_state_files(state_file_dir):
    """
    Map the "<mountpoint>.<port>" part of each state file name in state_file_dir to that file's name.

    Only non-directory entries whose name starts with "fs-" are considered. The key is everything after the
    first '.', e.g. "fs-deadbeef.home.user.mnt.12345" -> "home.user.mnt.12345" (a name without a '.' is used
    whole). The values are bare file names, not paths. Returns an empty dict if state_file_dir is not a
    directory.
    """
    state_files = {}
    if not os.path.isdir(state_file_dir):
        return state_files

    for entry in os.listdir(state_file_dir):
        if not entry.startswith('fs-'):
            continue
        if os.path.isdir(os.path.join(state_file_dir, entry)):
            continue

        # find() returns -1 when there is no '.', which makes the slice start at 0 (the whole name).
        mount_point_and_port = entry[entry.find('.') + 1:]
        logging.debug('Translating "%s" into mount point and port "%s"', entry, mount_point_and_port)
        state_files[mount_point_and_port] = entry

    return state_files


def is_pid_running(pid):
    """
    Return True if a process with this pid exists, probing with signal 0 (which checks existence without
    delivering a signal).

    A falsy pid (None or 0) counts as not running, as does any OSError raised by the probe.
    """
    if pid:
        try:
            os.kill(pid, 0)
        except OSError:
            return False
        return True
    return False


def start_tls_tunnel(child_procs, state_file, command):
    """
    Launch the TLS tunnel `command`, record its Popen object in child_procs and return its pid.

    The tunnel is started in a new session (os.setsid), so that it and any children it spawns can later be
    signalled together as one process group. Exits via fatal_error if the new pid is not running.
    """
    logging.info('Starting TLS tunnel: "%s"', ' '.join(command))
    tunnel_proc = subprocess.Popen(command, close_fds=True, preexec_fn=os.setsid)
    tunnel_pid = tunnel_proc.pid

    if not is_pid_running(tunnel_pid):
        fatal_error('Failed to initialize TLS tunnel for %s' % state_file, 'Failed to start TLS tunnel.')

    logging.info('Started TLS tunnel, pid: %d', tunnel_pid)
    child_procs.append(tunnel_proc)
    return tunnel_pid


460
def clean_up_mount_state(state_file_dir, state_file, pid, is_running, mount_state_dir=None):
    """
    Terminate a mount's TLS tunnel if it is running, then delete the mount's state once the tunnel is gone.

    :param state_file_dir: directory holding the state file (and the mount state directory, if any).
    :param state_file: state file name within state_file_dir.
    :param pid: tunnel pid recorded in the state file (may be None).
    :param is_running: whether the caller saw the tunnel running; if so, SIGTERM is sent to its process group.
    :param mount_state_dir: optional directory name under state_file_dir to remove as well.

    If the pid is still alive after signalling, nothing is deleted and a later pass retries. Otherwise the
    following are removed, in order:
    1. every path listed under "files" in the state file (already-missing files are ignored);
    2. the state file itself;
    3. mount_state_dir, when given and present.
    """
    if is_running:
        try:
            process_group = os.getpgid(pid)
            logging.info('Terminating running TLS tunnel - PID: %d, group ID: %s', pid, process_group)
            os.killpg(process_group, SIGTERM)
        except OSError as e:
            # The tunnel may exit between the caller's liveness check and these calls. ESRCH then only means
            # there is nothing left to signal, so continue with the cleanup instead of aborting the pass.
            if e.errno != errno.ESRCH:
                raise
            logging.debug('TLS tunnel: %d exited before it could be terminated', pid)

    if is_pid_running(pid):
        logging.info('TLS tunnel: %d is still running, will retry termination', pid)
        return

    if not pid:
        logging.info('TLS tunnel has been killed, cleaning up state')
    else:
        logging.info('TLS tunnel: %d is no longer running, cleaning up state', pid)

    state_file_path = os.path.join(state_file_dir, state_file)
    with open(state_file_path) as f:
        state = json.load(f)

    for path in state.get('files', list()):
        logging.debug('Deleting %s', path)
        try:
            os.remove(path)
            logging.debug('Deleted %s', path)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    os.remove(state_file_path)

    if mount_state_dir is None:
        return

    mount_state_dir_abs_path = os.path.join(state_file_dir, mount_state_dir)
    if os.path.isdir(mount_state_dir_abs_path):
        shutil.rmtree(mount_state_dir_abs_path)
    else:
        logging.debug('Attempt to remove mount state directory %s failed. Directory is not present.',
                      mount_state_dir_abs_path)

Max Beckett's avatar
Max Beckett committed
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514

def rewrite_state_file(state, state_file_dir, state_file):
    """
    Replace state_file in state_file_dir with the JSON serialization of `state`.

    The JSON is written to a sibling "~<state_file>" temp file, which is then renamed over the target. On POSIX,
    rename within one filesystem is atomic, so readers see either the old or the new contents, never a
    partially written file.
    """
    target_path = os.path.join(state_file_dir, state_file)
    staging_path = os.path.join(state_file_dir, '~%s' % state_file)

    with open(staging_path, 'w') as staging_fh:
        json.dump(state, staging_fh)

    os.rename(staging_path, target_path)


def mark_as_unmounted(state, state_file_dir, state_file, current_time):
    """Record current_time as the mount's 'unmount_time', persist the state file and return the updated state."""
    logging.debug('Marking %s as unmounted at %d', state_file, current_time)
    state.update(unmount_time=current_time)
    rewrite_state_file(state, state_file_dir, state_file)
    return state


def restart_tls_tunnel(child_procs, state, state_file_dir, state_file):
    """
    Relaunch the mount's TLS tunnel from state['cmd'] and save the new pid in the state file.

    If the state references a certificate file that no longer exists, nothing is started and an error is
    logged.
    """
    if 'certificate' in state:
        certificate_path = state['certificate']
        if not os.path.exists(certificate_path):
            logging.error('Cannot restart stunnel because self-signed certificate at %s is missing' % certificate_path)
            return

    tunnel_pid = start_tls_tunnel(child_procs, state_file, state['cmd'])
    state['pid'] = tunnel_pid

    logging.debug('Rewriting %s with new pid: %d', state_file, tunnel_pid)
    rewrite_state_file(state, state_file_dir, state_file)


526
def check_efs_mounts(config, child_procs, unmount_grace_period_sec, state_file_dir=STATE_FILE_DIR):
    """
    Run one watchdog pass over every mount state file in state_file_dir.

    Each state file is handled as follows:
    - already marked unmounted: once unmount_grace_period_sec has passed since unmount_time, the tunnel is
      terminated (if running) and the mount's state is deleted;
    - no matching local NFS mount: the state is marked unmounted at the current time;
    - otherwise: the certificate is checked/refreshed (if the mount uses one) and the TLS tunnel is restarted
      if it is not running.

    State files containing invalid JSON are logged and skipped.
    """
    nfs_mounts = get_current_local_nfs_mounts()
    logging.debug('Current local NFS mounts: %s', list(nfs_mounts.values()))

    state_files = get_state_files(state_file_dir)
    logging.debug('Current state files in "%s": %s', state_file_dir, list(state_files.values()))

    for mount_key, state_file in state_files.items():
        state_file_path = os.path.join(state_file_dir, state_file)
        with open(state_file_path) as state_fh:
            try:
                state = json.load(state_fh)
            except ValueError:
                logging.exception('Unable to parse json in %s', state_file_path)
                continue

        try:
            is_running = is_pid_running(state['pid'])
        except KeyError:
            logging.debug('Did not find PID in state file. Assuming stunnel is not running')
            is_running = False

        current_time = time.time()

        if 'unmount_time' in state:
            grace_period_expired = state['unmount_time'] + unmount_grace_period_sec < current_time
            if grace_period_expired:
                logging.info('Unmount grace period expired for %s', state_file)
                clean_up_mount_state(state_file_dir, state_file, state.get('pid'), is_running, state.get('mountStateDir'))
            continue

        if mount_key not in nfs_mounts:
            logging.info('No mount found for "%s"', state_file)
            mark_as_unmounted(state, state_file_dir, state_file, current_time)
            continue

        if 'certificate' in state:
            check_certificate(config, state, state_file_dir, state_file)

        if is_running:
            logging.debug('TLS tunnel for %s is running', state_file)
        else:
            logging.warning('TLS tunnel for %s is not running', state_file)
            restart_tls_tunnel(child_procs, state, state_file_dir, state_file)


def check_child_procs(child_procs):
    """
    Poll every child TLS tunnel process and remove the ones that have exited from child_procs, in place.

    Exited processes are logged at warning level with their return code.
    """
    # Iterate over a snapshot: removing items from the list being iterated makes the loop skip the element
    # that follows each removed process.
    for proc in list(child_procs):
        proc.poll()
        if proc.returncode is not None:
            logging.warning('Child TLS tunnel process %d has exited, returncode=%d', proc.pid, proc.returncode)
            child_procs.remove(proc)


def parse_arguments(args=None):
    """
    Handle the watchdog's command-line flags.

    args defaults to sys.argv (args[0] is the program name). -h/--help prints usage and --version prints the
    version, each exiting with status 0; help takes precedence if both are given. Any other arguments are
    ignored and None is returned.
    """
    argv = sys.argv if args is None else args
    flags = argv[1:]

    if '-h' in flags or '--help' in flags:
        sys.stdout.write('Usage: %s [--version] [-h|--help]\n' % argv[0])
        sys.exit(0)

    if '--version' in flags:
        sys.stdout.write('%s Version: %s\n' % (argv[0], VERSION))
        sys.exit(0)


def assert_root():
    """Exit with status 1 and an error on stderr unless the effective user ID is 0 (root)."""
    if os.geteuid() == 0:
        return
    sys.stderr.write('only root can run amazon-efs-mount-watchdog\n')
    sys.exit(1)


def read_config(config_file=CONFIG_FILE):
    """
    Load config_file into a new config parser and return the parser.

    On Python 2 the module-level name ConfigParser is the ConfigParser module, so its SafeConfigParser class is
    used. On Python 3 the name is bound to the configparser.ConfigParser class itself; that class has no
    SafeConfigParser attribute, so it is instantiated directly. read() silently skips a file it cannot open,
    so a missing file yields an empty parser.
    """
    try:
        parser_factory = ConfigParser.SafeConfigParser
    except AttributeError:
        parser_factory = ConfigParser

    parser = parser_factory()
    parser.read(config_file)
    return parser


606
607
608
def check_certificate(config, state, state_file_dir, state_file, base_path=STATE_FILE_DIR):
    """
    Recreate the mount's self-signed client certificate if it is missing or due for renewal.

    The certificate is due for renewal once more time than the configured renewal interval has passed since
    state['certificateCreationTime']. A malformed accessPoint value in the state aborts the refresh.

    After a successful refresh:
    - the new creation time is written back to the state file;
    - SIGHUP is sent to the tunnel's process group, so stunnel reloads its configuration and picks up the new
      certificate.
    """
    creation_time = datetime.strptime(state['certificateCreationTime'], CERT_DATETIME_FORMAT)
    certificate_exists = os.path.isfile(state['certificate'])
    renewal_interval_secs = get_certificate_renewal_interval_mins(config) * 60
    # creation instead of NOT_BEFORE datetime is used for refresh of cert because NOT_BEFORE derives from creation datetime
    certificate_age_secs = (get_utc_now() - creation_time).total_seconds()

    if certificate_exists and certificate_age_secs <= renewal_interval_secs:
        return

    ap_state = state.get('accessPoint')
    if ap_state and not AP_ID_RE.match(ap_state):
        logging.error('Access Point ID "%s" has been changed in the state file to a malformed format' % ap_state)
        return

    if certificate_exists:
        logging.debug('Refreshing self-signed certificate (at %s)' % state['certificate'])
    else:
        logging.debug('Certificate (at %s) is missing. Recreating self-signed certificate' % state['certificate'])

    new_creation_time = recreate_certificate(config, state['mountStateDir'], state['commonName'], state['fsId'],
                                             state.get('awsCredentialsMethod'), ap_state, state['region'],
                                             base_path=base_path)
    if not new_creation_time:
        return

    state['certificateCreationTime'] = new_creation_time
    rewrite_state_file(state, state_file_dir, state_file)

    # send SIGHUP to force a reload of the configuration file to trigger the stunnel process to notice the new certificate
    tunnel_pid = state.get('pid')
    if not is_pid_running(tunnel_pid):
        logging.warning('TLS tunnel is not running for %s', state_file)
        return

    process_group = os.getpgid(tunnel_pid)
    logging.info('SIGHUP signal to stunnel. PID: %d, group ID: %s', tunnel_pid, process_group)
    os.killpg(process_group, SIGHUP)


def create_required_directory(config, directory):
    """
    Create `directory`, and any missing parents, if it does not already exist.

    The mode comes from the octal state_file_dir_mode option in CONFIG_SECTION. It defaults to 0o750 when the
    option is absent or not valid octal, and is passed to os.makedirs, so it is subject to the process umask.
    An already existing directory is not an error; any other OSError is re-raised.
    """
    mode = 0o750
    try:
        mode_str = config.get(CONFIG_SECTION, 'state_file_dir_mode')
    except NoOptionError:
        mode_str = None

    if mode_str is not None:
        try:
            mode = int(mode_str, 8)
        except ValueError:
            logging.warning('Bad state_file_dir_mode "%s" in config file "%s"', mode_str, CONFIG_FILE)

    try:
        os.makedirs(directory, mode)
    except OSError as e:
        if e.errno != errno.EEXIST or not os.path.isdir(directory):
            raise
    else:
        logging.debug('Expected %s not found, recreating asset', directory)


662
663
664
665
666
667
668
669
670
671
672
673
def get_client_info(config):
    """
    Collect optional client metadata from the CLIENT_INFO_SECTION of the config.

    Only the 'source' option is read. It is included only when it is non-empty and at most
    CLIENT_SOURCE_STR_LEN_LIMIT characters long. Returns a dict, possibly empty.
    """
    if not config.has_option(CLIENT_INFO_SECTION, 'source'):
        return {}

    client_source = config.get(CLIENT_INFO_SECTION, 'source')
    if not 0 < len(client_source) <= CLIENT_SOURCE_STR_LEN_LIMIT:
        return {}

    return {'source': client_source}


674
675
def recreate_certificate(config, mount_name, common_name, fs_id, credentials_source, ap_id, region,
                         base_path=STATE_FILE_DIR):
    """
    Generate a fresh self-signed client certificate in the mount's TLS state directory.

    Steps:
    1. Make sure the local CA directories and supporting files exist.
    2. Make sure the shared private key exists.
    3. Write a public key next to the certificate when credentials_source is set.
    4. Write the openssl config via create_ca_conf.
    5. Create a CSR and sign it with `openssl ca -selfsign`. The certificate is valid from NOT_BEFORE_MINS
       minutes before now until NOT_AFTER_HOURS hours after now.

    Returns the creation time formatted with CERT_DATETIME_FORMAT. Returns None (after logging an error) if
    create_ca_conf returned nothing.
    """
    now = get_utc_now()
    paths = tls_paths_dictionary(mount_name, base_path)
    mount_dir = paths['mount_dir']

    certificate_config = os.path.join(mount_dir, 'config.conf')
    certificate_signing_request = os.path.join(mount_dir, 'request.csr')
    certificate = os.path.join(mount_dir, 'certificate.pem')

    ca_dirs_check(config, paths['database_dir'], paths['certs_dir'])
    ca_supporting_files_check(paths['index'], paths['index_attr'], paths['serial'], paths['rand'])

    private_key = check_and_create_private_key(base_path)

    if credentials_source:
        create_public_key(private_key, os.path.join(mount_dir, 'publicKey.pem'))

    config_body = create_ca_conf(certificate_config, common_name, mount_dir, private_key, now, region, fs_id,
                                 credentials_source, ap_id=ap_id, client_info=get_client_info(config))
    if not config_body:
        logging.error('Cannot recreate self-signed certificate')
        return None

    create_certificate_signing_request(certificate_config, private_key, certificate_signing_request)

    not_before = get_certificate_timestamp(now, minutes=-NOT_BEFORE_MINS)
    not_after = get_certificate_timestamp(now, hours=NOT_AFTER_HOURS)

    signing_cmd = 'openssl ca -startdate %s -enddate %s -selfsign -batch -notext -config %s -in %s -out %s' % (
        not_before, not_after, certificate_config, certificate_signing_request, certificate)
    subprocess_call(signing_cmd, 'Failed to create self-signed client-side certificate')
    return now.strftime(CERT_DATETIME_FORMAT)


def get_private_key_path():
    """Return the location of the shared private key file.

    Kept as a function (rather than reading PRIVATE_KEY_FILE directly at call sites) so unit tests can mock it.
    """
    private_key_location = PRIVATE_KEY_FILE
    return private_key_location


def check_and_create_private_key(base_path=STATE_FILE_DIR):
    """Ensure the shared RSA private key exists, creating it under a file lock when it is missing.

    :param base_path: directory in which the 'efs-utils-lock' lock file is created.
    :return: the private key path returned by get_private_key_path().
    :raises OSError: any OSError other than EEXIST raised while taking the lock or generating the key.
    """
    # Creating RSA private keys is slow, so we will create one private key and allow mounts to share it.
    # This means, however, that we have to include a locking mechanism to ensure that the private key is
    # atomically created, as mounts occurring in parallel may try to create the key simultaneously.
    # The key should have been created during mounting, but the watchdog will recreate the private key if
    # it is missing.
    key = get_private_key_path()

    @contextmanager
    def open_lock_file():
        """Acquire the lock by exclusively creating the lock file; release it by closing and deleting the file."""
        lock_file = os.path.join(base_path, 'efs-utils-lock')
        # O_CREAT | O_EXCL makes os.open fail with EEXIST if the lock file already exists, so a successful
        # open is what "holding the lock" means here.
        f = os.open(lock_file, os.O_CREAT | os.O_DSYNC | os.O_EXCL | os.O_RDWR)
        try:
            # Record the holder's PID in the lock file (nothing in this block reads it back).
            lock_file_contents = 'PID: %s' % os.getpid()
            os.write(f, lock_file_contents.encode('utf-8'))
            yield f
        finally:
            # Release the lock even if the body raised.
            os.close(f)
            os.remove(lock_file)

    def do_with_lock(function):
        """Run function() while holding the lock, retrying every 50 ms while another holder has it."""
        # NOTE(review): there is no timeout or stale-lock detection here. If a holder dies without removing
        # 'efs-utils-lock', this loops forever. Confirm whether stale lock files are cleaned up elsewhere.
        # NOTE(review): an OSError with errno EEXIST raised by function() itself is also treated as lock contention.
        while True:
            try:
                with open_lock_file():
                    return function()
            except OSError as e:
                if e.errno == errno.EEXIST:
                    logging.info('Failed to take out private key creation lock, sleeping 50 ms')
                    time.sleep(0.05)
                else:
                    raise

    def generate_key():
        """Create a 3072-bit RSA private key at `key` unless it already exists, then make it owner read-only."""
        # Re-check under the lock: another process may have created the key while we were waiting.
        if os.path.isfile(key):
            return

        cmd = 'openssl genpkey -algorithm RSA -out %s -pkeyopt rsa_keygen_bits:3072' % key
        subprocess_call(cmd, 'Failed to create private key')
        # 0o400: read permission for the owner only.
        read_only_mode = 0o400
        os.chmod(key, read_only_mode)

    do_with_lock(generate_key)
    return key


def create_certificate_signing_request(config_path, key_path, csr_path):
    """Write a certificate signing request to csr_path.

    Runs `openssl req -new`, signing with key_path and using the req settings in config_path.
    """
    openssl_req_cmd = 'openssl req -new -config {0} -key {1} -out {2}'.format(config_path, key_path, csr_path)
    subprocess_call(openssl_req_cmd, 'Failed to create certificate signing request (csr)')


def create_ca_conf(config_path, common_name, directory, private_key, date, region, fs_id, credentials_source,
                   ap_id=None, client_info=None):
    """Populate ca/req configuration file with fresh configurations at every mount since SigV4 signature can change.

    The rendered CA_CONFIG_BODY is written to config_path and also returned.

    :param config_path: file the rendered OpenSSL config is written to.
    :param directory: mount directory; it is interpolated into the config and must contain publicKey.pem when
        credentials_source is set.
    :param credentials_source: AWS credential lookup method. When falsy, no SigV4 [ efs_client_auth ] section is
        generated.
    :param ap_id: optional access point id, forwarded to ca_extension_builder.
    :param client_info: optional dict, rendered as the [ efs_client_info ] section when non-empty.
    :return: the config text, or None (nothing written) if credentials cannot be retrieved or the SigV4 section
        cannot be built.
    """
    public_key_path = os.path.join(directory, 'publicKey.pem')
    security_credentials = get_aws_security_credentials(credentials_source) if credentials_source else ''

    if credentials_source and security_credentials is None:
        logging.error('Failed to retrieve AWS security credentials using lookup method: %s', credentials_source)
        return None

    ca_extension_body = ca_extension_builder(ap_id, security_credentials, fs_id, client_info)
    # The SigV4 section is only produced when a credentials source was given; otherwise it renders as ''.
    efs_client_auth_body = efs_client_auth_builder(public_key_path, security_credentials['AccessKeyId'],
                                                   security_credentials['SecretAccessKey'], date, region, fs_id,
                                                   security_credentials['Token']) if credentials_source else ''
    if credentials_source and not efs_client_auth_body:
        logging.error('Failed to create AWS SigV4 signature section for OpenSSL config. Public Key path: %s', public_key_path)
        return None

    efs_client_info_body = efs_client_info_builder(client_info) if client_info else ''
    full_config_body = CA_CONFIG_BODY % (directory, private_key, common_name, ca_extension_body,
                                         efs_client_auth_body, efs_client_info_body)

    with open(config_path, 'w') as f:
        f.write(full_config_body)

    return full_config_body


def ca_extension_builder(ap_id, security_credentials, fs_id, client_info):
    """Build the [ v3_ca ] extension section of the OpenSSL CA config.

    The section always contains subjectKeyIdentifier and the file system id (OID 1.3.6.1.4.1.4843.7.3). Entries are
    emitted in this order:
      - 1.3.6.1.4.1.4843.7.1, the access point id, when ap_id is truthy
      - 1.3.6.1.4.1.4843.7.2, a reference to the [ efs_client_auth ] section, when security_credentials is truthy
      - 1.3.6.1.4.1.4843.7.3, the file system id (always)
      - 1.3.6.1.4.1.4843.7.4, a reference to the [ efs_client_info ] section, when client_info is truthy

    :return: the section text, lines separated by '\\n' with no trailing newline.
    """
    ca_extension_str = '[ v3_ca ]\nsubjectKeyIdentifier = hash'
    if ap_id:
        ca_extension_str += '\n1.3.6.1.4.1.4843.7.1 = ASN1:UTF8String:' + ap_id
    if security_credentials:
        ca_extension_str += '\n1.3.6.1.4.1.4843.7.2 = ASN1:SEQUENCE:efs_client_auth'

    ca_extension_str += '\n1.3.6.1.4.1.4843.7.3 = ASN1:UTF8String:' + fs_id

    if client_info:
        ca_extension_str += '\n1.3.6.1.4.1.4843.7.4 = ASN1:SEQUENCE:efs_client_info'

    return ca_extension_str


def efs_client_auth_builder(public_key_path, access_key_id, secret_access_key, date, region, fs_id, session_token=None):
    """Build the [ efs_client_auth ] section holding the SigV4 signature for the OpenSSL config.

    The signature covers the SHA-1 of the public key at public_key_path, tying it to that key pair.
    The sessionToken line is only present when session_token is truthy.

    :return: the section text, or None when the public key hash cannot be computed.
    """
    public_key_hash = get_public_key_sha1(public_key_path)
    if not public_key_hash:
        return None

    canonical_request = create_canonical_request(public_key_hash, date, access_key_id, region, fs_id, session_token)
    signature = calculate_signature(create_string_to_sign(canonical_request, date, region), date,
                                    secret_access_key, region)

    section_lines = [
        '[ efs_client_auth ]',
        'accessKeyId = UTF8String:' + access_key_id,
        'signature = OCTETSTRING:' + signature,
        'sigv4DateTime = UTCTIME:' + date.strftime(CERT_DATETIME_FORMAT),
    ]
    if session_token:
        section_lines.append('sessionToken = EXPLICIT:0,UTF8String:' + session_token)

    return '\n'.join(section_lines)


def efs_client_info_builder(client_info):
    """Render client_info as an OpenSSL [ efs_client_info ] section with one UTF8String entry per key.

    Entries appear in the dict's iteration order.
    """
    entry_lines = ['\n%s = UTF8String: %s' % (name, val) for name, val in client_info.items()]
    return '[ efs_client_info ]' + ''.join(entry_lines)


def create_public_key(private_key, public_key):
    """Extract the PEM public key from private_key with `openssl rsa -pubout` and write it to public_key."""
    pubout_cmd = 'openssl rsa -in {0} -outform PEM -pubout -out {1}'.format(private_key, public_key)
    subprocess_call(pubout_cmd, 'Failed to create public key')


def subprocess_call(cmd, error_message):
    """Helper method to run shell openssl command and to handle response error messages.

    :param cmd: command line string. It is split on whitespace (no shell), so individual arguments must not contain
        spaces.
    :param error_message: context prepended to the log entry when the command fails.
    :return: (stdout, stderr) as bytes when the command exits with status 0, otherwise None.
    """
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
    output, err = process.communicate()
    # communicate() waits for the process to exit, so returncode is populated here.
    rc = process.returncode
    if rc != 0:
        # Callers only get None back, so this entry is the only record of why e.g. key or certificate creation failed.
        # It is logged at error level so it is visible at normal logging levels.
        logging.error('%s. Command %s failed, rc=%s, stdout="%s", stderr="%s"', error_message, cmd, rc, output, err)
        return None
    return output, err


def ca_dirs_check(config, database_dir, certs_dir):
    """Create the mount's database and certs directories if they do not exist.

    Creation is delegated to create_required_directory, which the original docstring describes as also creating
    missing intermediate directories. The database directory is handled before the certs directory.
    """
    for required_dir in (database_dir, certs_dir):
        if os.path.exists(required_dir):
            continue
        create_required_directory(config, required_dir)


def ca_supporting_files_check(index_path, index_attr_path, serial_path, rand_path):
    """Recreate any missing openssl ca/req supporting files with their initial contents.

    Existing files are left untouched. Each file that has to be recreated is logged as a warning.
    """
    initial_contents = (
        (index_path, ''),
        (index_attr_path, 'unique_subject = no'),
        (serial_path, '00'),
        (rand_path, ''),
    )
    for file_path, contents in initial_contents:
        if os.path.isfile(file_path):
            continue
        with open(file_path, 'w') as handle:
            handle.write(contents)
        logging.warning('Expected %s not found, recreating file', file_path)


def tls_paths_dictionary(mount_name, base_path=STATE_FILE_DIR):
    """Return the per-mount TLS/openssl CA file and directory paths, all rooted at base_path/mount_name."""
    def _under_mount(*parts):
        return os.path.join(base_path, mount_name, *parts)

    return {
        'mount_dir': _under_mount(),
        'database_dir': _under_mount('database'),
        'certs_dir': _under_mount('certs'),
        'index': _under_mount('database/index.txt'),
        'index_attr': _under_mount('database/index.txt.attr'),
        'serial': _under_mount('database/serial'),
        'rand': _under_mount('database/.rand'),
    }


def get_public_key_sha1(public_key):
    """Compute the SHA-1 of the subjectPublicKey BIT STRING contents of a PEM public key file.

    Following RFC 5280 section 4.2.1.2, the hash covers only the key bits. The BIT STRING tag, its length octets and
    its unused-bits octet are excluded.

    :param public_key: path to a PEM-encoded public key file.
    :return: hex SHA-1 digest, or None if openssl cannot parse the file or no BIT STRING line is found.
    """
    # truncating public key to remove the header and footer '-----(BEGIN|END) PUBLIC KEY-----'
    with open(public_key, 'r') as f:
        lines = f.readlines()
        lines = lines[1:-1]

    key = ''.join(lines)
    key = bytearray(base64.b64decode(key))

    # Parse the public key to pull out the actual key material by looking for the key BIT STRING
    # Example:
    #     0:d=0  hl=4 l= 418 cons: SEQUENCE
    #     4:d=1  hl=2 l=  13 cons: SEQUENCE
    #     6:d=2  hl=2 l=   9 prim: OBJECT            :rsaEncryption
    #    17:d=2  hl=2 l=   0 prim: NULL
    #    19:d=1  hl=4 l= 399 prim: BIT STRING
    cmd = 'openssl asn1parse -inform PEM -in %s' % public_key
    result = subprocess_call(cmd, 'Unable to ASN1 parse public key file, %s, correctly' % public_key)
    if result is None:
        # subprocess_call returns None when openssl fails; unpacking it would raise TypeError.
        logging.error('Unable to ASN1 parse public key file, %s', public_key)
        return None
    output, _ = result

    key_line = ''
    for line in output.splitlines():
        decoded_line = line.decode('utf-8')
        if 'BIT STRING' in decoded_line:
            key_line = decoded_line

    if not key_line:
        logging.error('Public key file, %s, is incorrectly formatted', public_key)
        return None

    key_line = key_line.replace(' ', '')

    # DER encoding TLV (Tag, Length, Value)
    # - the first octet (byte) is the tag (type)
    # - the next octets are the length - "definite form":
    #   - short form (length <= 127): a single octet with the high order bit (8) clear holds the length itself
    #   - long form: the first octet has the high order bit set, and its remaining 7 bits give the number of
    #     octets that follow. Those octets encode the length as a big-endian number.
    # - the remaining octets are the "value" aka content
    #
    # For a BIT STRING, the first octet of the value is used to signify the number of unused bits that exist in the last
    # content byte. Note that this is explicitly excluded from the SubjectKeyIdentifier hash, per
    # https://tools.ietf.org/html/rfc5280#section-4.2.1.2
    #
    # Example:
    #   0382018f00...<subjectPublicKey>
    #   - 03 - BIT STRING tag
    #   - 82 - 2 length octets to follow (ignore high order bit)
    #   - 018f - length of 399
    #   - 00 - no unused bits in the last content byte
    offset = int(key_line.split(':')[0])
    key = key[offset:]

    first_length_octet = key[1]
    if first_length_octet & 0b10000000:
        # Long form: the low 7 bits count the additional length octets.
        num_length_octets = first_length_octet & 0b01111111
    else:
        # Short form: this octet is the length itself, so no additional length octets follow.
        num_length_octets = 0

    # Exclude the tag (1), length (1 + num_length_octets), and number of unused bits (1)
    offset = 1 + 1 + num_length_octets + 1
    key = key[offset:]

    sha1 = hashlib.sha1()
    sha1.update(key)

    return sha1.hexdigest()


def create_canonical_request(public_key_hash, date, access_key, region, fs_id, session_token=None):
    """
    Create a Canonical Request - https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html

    The request consists of six newline-separated parts:
      - the HTTP method
      - the canonical URI
      - the canonical query string
      - the canonical headers for fs_id
      - the signed headers
      - the hex SHA-256 of the request payload
    """
    sigv4_timestamp = date.strftime(SIGV4_DATETIME_FORMAT)
    scoped_credential = quote_plus(access_key + '/' + get_credential_scope(date, region))
    query_string = create_canonical_query_string(public_key_hash, scoped_credential, sigv4_timestamp, session_token)
    payload_digest = hashlib.sha256(REQUEST_PAYLOAD.encode()).hexdigest()

    return '\n'.join([
        HTTP_REQUEST_METHOD,
        CANONICAL_URI,
        query_string,
        CANONICAL_HEADERS % fs_id,
        SIGNED_HEADERS,
        payload_digest,
    ])


def create_canonical_query_string(public_key_hash, credential, formatted_datetime, session_token=None):
    """Build the SigV4 canonical query string for the 'Connect' action.

    Parameters are sorted by name and joined as name=value pairs with '&'. X-Amz-Security-Token is included only
    when session_token is truthy.
    """
    query_params = dict([
        ('Action', 'Connect'),
        # Public key hash is included in canonical request to tie the signature to a specific key pair to avoid replay attacks
        ('PublicKeyHash', quote_plus(public_key_hash)),
        ('X-Amz-Algorithm', ALGORITHM),
        ('X-Amz-Credential', credential),
        ('X-Amz-Date', quote_plus(formatted_datetime)),
        ('X-Amz-Expires', 86400),
        ('X-Amz-SignedHeaders', SIGNED_HEADERS),
    ])

    if session_token:
        query_params['X-Amz-Security-Token'] = quote_plus(session_token)

    # Cannot use urllib.urlencode because it replaces the %s's
    ordered_pairs = sorted(query_params.items())
    return '&'.join('%s=%s' % pair for pair in ordered_pairs)


def create_string_to_sign(canonical_request, date, region):
    """
    Create a String to Sign - https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html

    The result is the algorithm, the request timestamp, the credential scope and the hex SHA-256 of
    canonical_request, separated by newlines.
    """
    header_parts = [
        ALGORITHM,
        date.strftime(SIGV4_DATETIME_FORMAT),
        get_credential_scope(date, region),
    ]
    request_digest = hashlib.sha256(canonical_request.encode()).hexdigest()
    return '\n'.join(header_parts + [request_digest])


def calculate_signature(string_to_sign, date, secret_access_key, region):
    """
    Calculate the Signature - https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html

    The signing key is derived by chaining HMAC-SHA256 over the date, region, service and 'aws4_request'. It is then
    used to HMAC string_to_sign. Returns the hex digest.
    """
    def _hmac_sha256(hmac_key, message):
        return hmac.new(hmac_key, message.encode('utf-8'), hashlib.sha256)

    derived_key = ('AWS4' + secret_access_key).encode('utf-8')
    for scope_component in (date.strftime(DATE_ONLY_FORMAT), region, SERVICE, 'aws4_request'):
        derived_key = _hmac_sha256(derived_key, scope_component).digest()

    return _hmac_sha256(derived_key, string_to_sign).hexdigest()


def get_certificate_renewal_interval_mins(config):
    """Return the self-signed certificate renewal interval in minutes.

    The value comes from tls_cert_renewal_interval_min in the watchdog config section. The function falls back to
    DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN, with a warning, when the option is missing, is not an integer,
    or is less than 1.
    """
    fallback = DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN

    try:
        configured_value = config.get(CONFIG_SECTION, 'tls_cert_renewal_interval_min')
    except NoOptionError:
        logging.warning('No tls_cert_renewal_interval_min value in config file "%s". Defaulting to %d minutes.', CONFIG_FILE,
                        DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN)
        return fallback

    try:
        configured_minutes = int(configured_value)
    except ValueError:
        logging.warning('Bad tls_cert_renewal_interval_min value, "%s", in config file "%s". Defaulting to %d minutes.',
                        configured_value, CONFIG_FILE, DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN)
        return fallback

    if configured_minutes > 0:
        return configured_minutes

    logging.warning('tls_cert_renewal_interval_min value in config file "%s" is lower than 1 minute. Defaulting '
                    'to %d minutes.', CONFIG_FILE, DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN)
    return fallback


def get_credential_scope(date, region):
    """Return the SigV4 credential scope: the formatted date, region, SERVICE and AWS4_REQUEST joined with '/'."""
    scope_parts = [date.strftime(DATE_ONLY_FORMAT), region, SERVICE, AWS4_REQUEST]
    return '/'.join(scope_parts)


def get_certificate_timestamp(current_time, **kwargs):
    """Shift current_time by timedelta(**kwargs) and format it with CERT_DATETIME_FORMAT."""
    return (current_time + timedelta(**kwargs)).strftime(CERT_DATETIME_FORMAT)


def get_utc_now():
    """
    Return the current UTC time as a naive datetime (no tzinfo), via datetime.utcnow().

    Wrapped in a function so unit tests can patch it.
    """
    utc_now = datetime.utcnow()
    return utc_now


def check_process_name(pid):
    """Return the raw bytes of /proc/<pid>/cmdline, as printed by `cat`.

    The result is empty bytes when cat prints nothing to stdout, e.g. because the file does not exist.
    """
    cmdline_path = '/proc/{pid}/cmdline'.format(pid=pid)
    cat_proc = subprocess.Popen(['cat', cmdline_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
    stdout_bytes, _stderr_bytes = cat_proc.communicate()
    return stdout_bytes


def clean_up_previous_stunnel_pids(state_file_dir=STATE_FILE_DIR):
    """
    Drop stale stunnel PIDs from persisted mount state files.

    After an efs-csi-driver restart, upgrade or crash, the state files may still reference stunnel processes started
    by a previous driver's mount watchdog. For each state file with a 'pid' whose process cmdline (via
    check_process_name) does not mention 'stunnel', the 'pid' entry is removed and the file is rewritten. This lets
    the watchdog create a new stunnel. Unparseable files and files without a 'pid' are skipped.
    """
    state_files = get_state_files(state_file_dir)
    logging.debug('Persisted state files in "%s": %s', state_file_dir, list(state_files.values()))

    for state_file_name in state_files.values():
        state_file_path = os.path.join(state_file_dir, state_file_name)
        with open(state_file_path) as state_fp:
            try:
                state = json.load(state_fp)
            except ValueError:
                logging.exception('Unable to parse json in %s', state_file_path)
                continue

        try:
            pid = state['pid']
        except KeyError:
            logging.debug('No PID found in state file %s', state_file_name)
            continue

        process_cmdline = check_process_name(pid)
        if process_cmdline and 'stunnel' in str(process_cmdline):
            logging.debug('PID %s in state file %s is active. Skipping clean up', pid, state_file_name)
            continue

        state.pop('pid')
        logging.debug('Cleaning up pid %s in state file %s', pid, state_file_name)
        rewrite_state_file(state, state_file_dir, state_file_name)


def main():
    """Entry point for amazon-efs-mount-watchdog.

    Calls parse_arguments() and assert_root(), then configures logging from the config file. If the 'enabled' option
    of the watchdog config section is set, it runs clean_up_previous_stunnel_pids() once and then loops forever.
    Each iteration re-reads the config, calls check_efs_mounts and check_child_procs, and sleeps for
    poll_interval_sec. Otherwise it logs that the watchdog is not enabled and returns.
    """
    parse_arguments()
    assert_root()

    config = read_config()
    bootstrap_logging(config)

    child_procs = []

    if config.getboolean(CONFIG_SECTION, 'enabled'):
        logging.info('amazon-efs-mount-watchdog, version %s, is enabled and started', VERSION)
        # Poll interval and grace period are read once at startup; later config changes to them are not picked up.
        poll_interval_sec = config.getint(CONFIG_SECTION, 'poll_interval_sec')
        unmount_grace_period_sec = config.getint(CONFIG_SECTION, 'unmount_grace_period_sec')

        clean_up_previous_stunnel_pids()

        while True:
            # Re-read the config every iteration so check_efs_mounts always sees the current file contents.
            config = read_config()
            check_efs_mounts(config, child_procs, unmount_grace_period_sec)
            check_child_procs(child_procs)

            time.sleep(poll_interval_sec)
    else:
        logging.info('amazon-efs-mount-watchdog is not enabled')

if __name__ == '__main__':
    main()