#!/usr/bin/env python3
#
# Copyright 2017-2018 Amazon.com, Inc. and its affiliates. All Rights Reserved.
#
# Licensed under the MIT License. See the LICENSE accompanying this file
# for the specific language governing permissions and limitations under
# the License.
#

10
import base64
Max Beckett's avatar
Max Beckett committed
11
import errno
12
13
import hashlib
import hmac
Max Beckett's avatar
Max Beckett committed
14
15
16
17
import json
import logging
import logging.handlers
import os
18
19
20
import pwd
import re
import shutil
Max Beckett's avatar
Max Beckett committed
21
22
23
24
25
import subprocess
import sys
import time

from collections import namedtuple
26
27
from contextlib import contextmanager
from datetime import datetime, timedelta
Max Beckett's avatar
Max Beckett committed
28
from logging.handlers import RotatingFileHandler
29
from signal import SIGTERM, SIGHUP
Max Beckett's avatar
Max Beckett committed
30
31

try:
32
33
    from configparser import ConfigParser, NoOptionError, NoSectionError
except ImportError:
Max Beckett's avatar
Max Beckett committed
34
    import ConfigParser
35
    from ConfigParser import NoOptionError, NoSectionError
Max Beckett's avatar
Max Beckett committed
36

37
38
39
40
41
42
try:
    from urllib.parse import quote_plus
except ImportError:
    from urllib import quote_plus

try:
43
    from urllib.request import urlopen, Request
44
    from urllib.error import URLError, HTTPError
45
    from urllib.parse import urlencode
46
47
48
except ImportError:
    from urllib2 import URLError, HTTPError, build_opener, urlopen, Request, HTTPHandler
    from urllib import urlencode
49

50

51
# Package version; printed by --version and passed along as client info by get_client_info.
VERSION = '1.29.1'
SERVICE = 'elasticfilesystem'

CONFIG_FILE = '/etc/amazon/efs/efs-utils.conf'
CONFIG_SECTION = 'mount-watchdog'
CLIENT_INFO_SECTION = 'client-info'
# Longest 'source' value from the client-info config section that get_client_info will accept.
CLIENT_SOURCE_STR_LEN_LIMIT = 100
DEFAULT_UNKNOWN_VALUE = 'unknown'

LOG_DIR = '/var/log/amazon/efs'
LOG_FILE = 'mount-watchdog.log'

STATE_FILE_DIR = '/var/run/efs'

PRIVATE_KEY_FILE = '/etc/amazon/efs/privateKey.pem'
DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN = 60
# Validity window of regenerated self-signed certificates, relative to their creation time.
NOT_BEFORE_MINS = 15
NOT_AFTER_HOURS = 3
DATE_ONLY_FORMAT = '%Y%m%d'
SIGV4_DATETIME_FORMAT = '%Y%m%dT%H%M%SZ'
CERT_DATETIME_FORMAT = '%y%m%d%H%M%SZ'

# Shared AWS credentials/config files in the home directory of the user running this process.
AWS_CREDENTIALS_FILES = {
    'credentials': os.path.expanduser(os.path.join('~' + pwd.getpwuid(os.getuid()).pw_name, '.aws', 'credentials')),
    'config': os.path.expanduser(os.path.join('~' + pwd.getpwuid(os.getuid()).pw_name, '.aws', 'config')),
}

# Template for the OpenSSL configuration used to (re)issue a mount's self-signed client certificate. It takes six
# positional %s substitutions, in order: the CA working directory, the private key path, the certificate common name,
# and three further sections appended at the end (presumably filled in by create_ca_conf, which is not shown here).
CA_CONFIG_BODY = """dir = %s
RANDFILE = $dir/database/.rand

[ ca ]
default_ca = local_ca

[ local_ca ]
database = $dir/database/index.txt
serial = $dir/database/serial
private_key = %s
cert = $dir/certificate.pem
new_certs_dir = $dir/certs
default_md = sha256
preserve = no
policy = efsPolicy
x509_extensions = v3_ca

[ efsPolicy ]
CN = supplied

[ req ]
prompt = no
distinguished_name = req_distinguished_name

[ req_distinguished_name ]
CN = %s

%s

%s

%s
"""

# SigV4 Auth
ALGORITHM = 'AWS4-HMAC-SHA256'
AWS4_REQUEST = 'aws4_request'

HTTP_REQUEST_METHOD = 'GET'
CANONICAL_URI = '/'
# The host value is a %s placeholder that is filled in later.
CANONICAL_HEADERS_DICT = {
    'host': '%s'
}
CANONICAL_HEADERS = '\n'.join(['%s:%s' % (k, v) for k, v in sorted(CANONICAL_HEADERS_DICT.items())])
# Sorted so the signed-header list always lines up with CANONICAL_HEADERS, even if more headers are added.
SIGNED_HEADERS = ';'.join(sorted(CANONICAL_HEADERS_DICT.keys()))
REQUEST_PAYLOAD = ''

AP_ID_RE = re.compile(r'^fsap-[0-9a-f]{17}$')

# ECS task endpoint; the relative URI from AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is appended to it.
ECS_TASK_METADATA_API = 'http://169.254.170.2'
STS_ENDPOINT_URL = 'https://sts.amazonaws.com/'
INSTANCE_IAM_URL = 'http://169.254.169.254/latest/meta-data/iam/security-credentials/'
INSTANCE_METADATA_TOKEN_URL = 'http://169.254.169.254/latest/api/token'
SECURITY_CREDS_ECS_URI_HELP_URL = 'https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html'
SECURITY_CREDS_WEBIDENTITY_HELP_URL = 'https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html'
SECURITY_CREDS_IAM_ROLE_HELP_URL = 'https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html'

# One whitespace-separated line of /proc/mounts (see get_current_local_nfs_mounts).
Mount = namedtuple('Mount', ['server', 'mountpoint', 'type', 'options', 'freq', 'passno'])


def fatal_error(user_message, log_message=None):
    """
    Report an unrecoverable error and terminate the process with exit status 1.

    :param user_message: text written, with a trailing newline, to stderr.
    :param log_message: text logged at ERROR level; defaults to ``user_message``.
    """
    logged = user_message if log_message is None else log_message
    sys.stderr.write('%s\n' % user_message)
    logging.error(logged)
    sys.exit(1)


def get_aws_security_credentials(credentials_source):
    """
    Lookup AWS security credentials (access key ID and secret access key). Adapted credentials provider chain from:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html and
    https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html

    :param credentials_source: '<method>:<value>' string recorded in the mount state file. ``method`` is one of
        'credentials' or 'config' (value is the profile name), 'ecs' (value is the relative credentials URI),
        'webidentity' (value is '<role_arn>,<token_file>') or 'metadata' (value is ignored).
    :return: the credentials dict from the selected provider, or None if the string is malformed or the provider
        found nothing.
    """
    # partition() rather than split(':', 1): a string without ':' (or None) is reported as improper instead of
    # raising ValueError/AttributeError out of the watchdog loop.
    method, separator, value = (credentials_source or '').partition(':')

    if separator:
        if method in ('credentials', 'config'):
            return get_aws_security_credentials_from_file(method, value)
        if method == 'ecs':
            return get_aws_security_credentials_from_ecs(value)
        if method == 'webidentity':
            webidentity_args = value.split(',')
            # any other number of parts would raise TypeError when unpacked into the two-argument call below
            if len(webidentity_args) == 2:
                return get_aws_security_credentials_from_webidentity(*webidentity_args)
        if method == 'metadata':
            return get_aws_security_credentials_from_instance_metadata()

    logging.error('Improper credentials source string "%s" found from mount state file', credentials_source)
    return None

def get_aws_ec2_metadata_token(timeout=1):
    """
    Request an IMDSv2 session token with a PUT to INSTANCE_METADATA_TOKEN_URL and return the raw response body.

    Python 2's urllib2.Request has no ``method`` argument, so on Python 2 the request goes through a urllib2 opener
    with get_method overridden. Under Python 3 this module does not import ``build_opener``/``HTTPHandler``. That
    lookup therefore raises NameError, and the urllib.request form (which accepts ``method='PUT'``) is used instead.

    :param timeout: seconds to wait for the metadata service. Without a timeout, an unresponsive endpoint would
        block the caller indefinitely. The default of 1 second matches the timeout url_request_helper uses.
    """
    ttl_header = {'X-aws-ec2-metadata-token-ttl-seconds': 21600}

    # Only the opener lookup is guarded, so a NameError raised anywhere else is not silently taken as "Python 3".
    try:
        opener = build_opener(HTTPHandler)
    except NameError:
        request = Request(INSTANCE_METADATA_TOKEN_URL, headers=ttl_header, method='PUT')
        return urlopen(request, timeout=timeout).read()

    request = Request(INSTANCE_METADATA_TOKEN_URL, headers=ttl_header)
    request.get_method = lambda: 'PUT'
    return opener.open(request, timeout=timeout).read()


def get_aws_security_credentials_from_file(file_name, awsprofile):
    """
    Look up credentials for a named profile in one of the shared AWS files (~/.aws/credentials or ~/.aws/config).

    :param file_name: key into AWS_CREDENTIALS_FILES, i.e. 'credentials' or 'config'.
    :param awsprofile: profile (section) to read.
    :return: the dict built by credentials_file_helper if it holds an access key ID. Otherwise an error is logged
        and None is returned.
    """
    path = AWS_CREDENTIALS_FILES.get(file_name)
    found = credentials_file_helper(path, awsprofile) if path and os.path.exists(path) else None
    if found and found['AccessKeyId']:
        return found

    logging.error('AWS security credentials not found in %s under named profile [%s]', path, awsprofile)
    return None


def get_aws_security_credentials_from_ecs(uri):
    """
    Fetch credentials from the ECS task credentials endpoint.

    :param uri: relative credentials path (the AWS_CONTAINER_CREDENTIALS_RELATIVE_URI value), appended to
        ECS_TASK_METADATA_API.
    :return: the decoded response if it contains AccessKeyId, SecretAccessKey and Token; otherwise None.
    """
    full_uri = ECS_TASK_METADATA_API + uri
    not_ok_msg = 'Unsuccessful retrieval of AWS security credentials at %s.' % full_uri
    unreachable_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
        (full_uri, SECURITY_CREDS_ECS_URI_HELP_URL)

    response = url_request_helper(full_uri, not_ok_msg, unreachable_msg)
    required = ('AccessKeyId', 'SecretAccessKey', 'Token')
    if not response or any(field not in response for field in required):
        return None
    return response


def get_aws_security_credentials_from_webidentity(role_arn, token_file):
    """
    Exchange a web identity token for temporary credentials via STS AssumeRoleWithWebIdentity.

    :param role_arn: ARN of the role to assume.
    :param token_file: path of the file holding the web identity token.
    :return: dict with AccessKeyId, SecretAccessKey and Token (taken from the STS SessionToken), or None if the token
        file cannot be read or the response lacks any of those fields.
    """
    try:
        with open(token_file, 'r') as token_fh:
            web_identity_token = token_fh.read()
    except Exception as err:
        logging.error('Error reading token file %s: %s', token_file, err)
        return None

    # Parameter order is kept identical so the encoded query string is byte-for-byte the same.
    query = urlencode({
        'Version': '2011-06-15',
        'Action': 'AssumeRoleWithWebIdentity',
        'RoleArn': role_arn,
        'RoleSessionName': 'efs-mount-helper',
        'WebIdentityToken': web_identity_token
    })
    sts_url = STS_ENDPOINT_URL + '?' + query

    not_ok_msg = 'Unsuccessful retrieval of AWS security credentials at %s.' % STS_ENDPOINT_URL
    unreachable_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
        (STS_ENDPOINT_URL, SECURITY_CREDS_WEBIDENTITY_HELP_URL)
    body = url_request_helper(sts_url, not_ok_msg, unreachable_msg, headers={'Accept': 'application/json'})
    if not body:
        return None

    response = body.get('AssumeRoleWithWebIdentityResponse', {})
    result = response.get('AssumeRoleWithWebIdentityResult', {})
    sts_creds = result.get('Credentials', {})
    if not all(field in sts_creds for field in ('AccessKeyId', 'SecretAccessKey', 'SessionToken')):
        return None

    return {
        'AccessKeyId': sts_creds['AccessKeyId'],
        'SecretAccessKey': sts_creds['SecretAccessKey'],
        'Token': sts_creds['SessionToken']
    }


def get_aws_security_credentials_from_instance_metadata():
    """
    Fetch credentials for the IAM role attached to this EC2 instance.

    This is a two-step lookup: first the role name is read from INSTANCE_IAM_URL, then the credentials are read from
    INSTANCE_IAM_URL + <role name>. Both requests retry once with an IMDSv2 token on HTTP 401.

    :return: the credentials dict if it contains AccessKeyId, SecretAccessKey and Token; otherwise None.
    """
    role_not_ok_msg = 'Unsuccessful retrieval of IAM role name at %s.' % INSTANCE_IAM_URL
    role_unreachable_msg = 'Unable to reach %s to retrieve IAM role name. See %s for more info.' % \
        (INSTANCE_IAM_URL, SECURITY_CREDS_IAM_ROLE_HELP_URL)
    role_name = url_request_helper(INSTANCE_IAM_URL, role_not_ok_msg, role_unreachable_msg,
                                   retry_with_new_header_token=True)
    if not role_name:
        return None

    creds_url = INSTANCE_IAM_URL + role_name
    creds_not_ok_msg = 'Unsuccessful retrieval of AWS security credentials at %s.' % creds_url
    creds_unreachable_msg = 'Unable to reach %s to retrieve AWS security credentials. See %s for more info.' % \
        (creds_url, SECURITY_CREDS_IAM_ROLE_HELP_URL)
    creds = url_request_helper(creds_url, creds_not_ok_msg, creds_unreachable_msg,
                               retry_with_new_header_token=True)

    if creds and all(field in creds for field in ('AccessKeyId', 'SecretAccessKey', 'Token')):
        return creds
    return None


def credentials_file_helper(file_path, awsprofile):
    """
    Read the access key ID, secret access key and session token of profile ``awsprofile`` from the INI-style AWS
    file at ``file_path``.

    :return: dict with keys 'AccessKeyId', 'SecretAccessKey' and 'Token'. A profile without aws_session_token still
        yields its key pair, with Token left as None. A missing key ID, missing secret key or missing profile
        section leaves every value None.
    """
    parser = read_config(file_path)
    creds = dict(AccessKeyId=None, SecretAccessKey=None, Token=None)

    def option(name):
        return parser.get(awsprofile, name)

    try:
        key_id = option('aws_access_key_id')
        secret = option('aws_secret_access_key')
        token = option('aws_session_token')
        creds.update(AccessKeyId=key_id, SecretAccessKey=secret, Token=token)
    except NoOptionError as err:
        missing = str(err)
        if 'aws_access_key_id' in missing or 'aws_secret_access_key' in missing:
            logging.debug('aws_access_key_id or aws_secret_access_key not found in %s under named profile [%s]',
                          file_path, awsprofile)
        if 'aws_session_token' in missing:
            logging.debug('aws_session_token not found in %s', file_path)
            creds.update(AccessKeyId=option('aws_access_key_id'),
                         SecretAccessKey=option('aws_secret_access_key'))
    except NoSectionError:
        logging.debug('No [%s] section found in config file %s', awsprofile, file_path)

    return creds


296
def url_request_helper(url, unsuccessful_resp, url_error_msg, headers={}, retry_with_new_header_token=False):
297
    try:
298
299
300
301
        req = Request(url)
        for k, v in headers.items():
            req.add_header(k, v)
        request_resp = urlopen(req, timeout=1)
302

303
304
305
306
307
308
309
310
311
312
        return get_resp_obj(request_resp, url, unsuccessful_resp)
    except HTTPError as e:
        # For instance enable with IMDSv2, Unauthorized 401 error will be thrown,
        # to retrieve metadata, the header should embeded with metadata token
        if e.code == 401 and retry_with_new_header_token:
            token = get_aws_ec2_metadata_token()
            req.add_header('X-aws-ec2-metadata-token', token)
            request_resp = urlopen(req, timeout=1)
            return get_resp_obj(request_resp, url, unsuccessful_resp)
        err_msg = 'Unable to reach the url at %s: status=%d, reason is %s' % (url, e.code, e.reason)
313
    except URLError as e:
314
315
316
317
318
319
320
321
322
323
        err_msg = 'Unable to reach the url at %s, reason is %s' % (url, e.reason)

    if err_msg:
        logging.debug('%s %s', url_error_msg, err_msg)
    return None


def get_resp_obj(request_resp, url, unsuccessful_resp):
    """
    Decode a urlopen() response.

    :param request_resp: response object returned by urlopen().
    :param url: URL the response came from; used only in log messages.
    :param unsuccessful_resp: message prefix logged when the status code is not 200.
    :return: None for a non-200 response. Otherwise the body parsed as JSON, or the body as text when it is not
        JSON (the instance-metadata role-name lookup relies on this, since it concatenates the result onto a URL).
    """
    status = request_resp.getcode()
    if status != 200:
        # The prefix is passed as an argument so a '%' inside it cannot corrupt the log format string.
        logging.debug('%s %s: ResponseCode=%d', unsuccessful_resp, url, status)
        return None

    resp_body = request_resp.read()
    try:
        if isinstance(resp_body, str):
            return json.loads(resp_body)
        return json.loads(resp_body.decode(request_resp.headers.get_content_charset() or 'us-ascii'))
    except ValueError as e:
        # Also reached for UnicodeDecodeError (a ValueError subclass). Decoding with 'replace' guarantees the
        # fallback itself cannot raise on a non-UTF-8 body.
        logging.info('ValueError parsing "%s" into json: %s. Returning response body.', resp_body, e)
        return resp_body if isinstance(resp_body, str) else resp_body.decode('utf-8', 'replace')

def bootstrap_logging(config, log_dir=LOG_DIR):
    """
    Configure the root logger to write to a size-rotated LOG_FILE inside ``log_dir``.

    Reads logging_level, logging_max_bytes and logging_file_count from the CONFIG_SECTION section of ``config``.
    An unrecognised level falls back to INFO. That problem is reported only after the file handler is installed, so
    the message reaches the log.
    """
    level_names = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'critical': logging.CRITICAL
    }
    raw_level = config.get(CONFIG_SECTION, 'logging_level')
    level = level_names.get(raw_level.lower())
    bad_level = not level
    if bad_level:
        level = logging.INFO

    handler = RotatingFileHandler(os.path.join(log_dir, LOG_FILE),
                                  maxBytes=config.getint(CONFIG_SECTION, 'logging_max_bytes'),
                                  backupCount=config.getint(CONFIG_SECTION, 'logging_file_count'))
    handler.setFormatter(logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s'))

    root = logging.getLogger()
    root.setLevel(level)
    root.addHandler(handler)

    if bad_level:
        logging.error('Malformed logging level "%s", setting logging level to %s', raw_level, level)

def parse_options(options):
    """
    Parse a comma-separated mount option string into a dict.

    A "key=value" entry maps key -> value. Only the first '=' separates, so values may themselves contain '='
    (previously such an entry raised ValueError). A bare flag maps to None.
    For example, 'port=20049,tls' -> {'port': '20049', 'tls': None}.
    """
    opts = {}
    for option in options.split(','):
        key, separator, value = option.partition('=')
        opts[key] = value if separator else None
    return opts


def get_file_safe_mountpoint(mount):
    """
    Build the key that matches a mount to its watchdog state file.

    The key is the absolute mountpoint with path separators turned into dots (leading dot dropped), followed by
    '.<port>'. Returns None when the mount options carry no 'port'.
    """
    dotted = os.path.abspath(mount.mountpoint).replace(os.sep, '.')
    file_safe = dotted[1:] if dotted.startswith('.') else dotted

    mount_opts = parse_options(mount.options)
    if 'port' in mount_opts:
        return file_safe + '.' + mount_opts['port']
    # some other localhost nfs mount not running over stunnel
    return None


def get_current_local_nfs_mounts(mount_file='/proc/mounts'):
    """
    Return the NFS mounts whose server address starts with 127.0.0.1.

    The result is a dict keyed by the mountpoint-and-port string produced by get_file_safe_mountpoint, which is the
    same form used in watchdog state file names. Mounts for which no key can be built (no 'port' option) are left out.
    """
    with open(mount_file) as mounts_fh:
        all_mounts = [Mount._make(line.strip().split()) for line in mounts_fh]

    by_key = {}
    for mnt in all_mounts:
        if not mnt.server.startswith('127.0.0.1') or 'nfs' not in mnt.type:
            continue
        key = get_file_safe_mountpoint(mnt)
        if key:
            by_key[key] = mnt
    return by_key


def get_state_files(state_file_dir):
    """
    Map each watchdog state file in ``state_file_dir`` to its mountpoint-and-port key.

    Only non-directory entries whose name starts with 'fs-' are considered. The key is everything after the first
    '.', so "fs-deadbeef.home.user.mnt.12345" is keyed "home.user.mnt.12345". Values are bare file names (not joined
    with state_file_dir). A missing directory yields an empty dict.
    """
    found = {}
    if not os.path.isdir(state_file_dir):
        return found

    for name in os.listdir(state_file_dir):
        if name.startswith('fs-') and not os.path.isdir(os.path.join(state_file_dir, name)):
            key = name[name.find('.') + 1:]
            logging.debug('Translating "%s" into mount point and port "%s"', name, key)
            found[key] = name
    return found


def is_pid_running(pid):
    """
    Return True if ``pid`` can be signalled, probing with signal 0 (which checks the target without delivering anything).

    A falsy pid (None or 0) is never considered running, and any OSError from the probe yields False.
    """
    if pid:
        try:
            os.kill(pid, 0)
        except OSError:
            return False
        return True
    return False


def start_tls_tunnel(child_procs, state_file, command):
    """
    Launch the TLS tunnel ``command`` and record the child process in ``child_procs``.

    The child runs in a new session (os.setsid), so the tunnel and anything it spawns share one process group that
    can be signalled as a unit. If the new pid is not running, the process exits via fatal_error.

    :return: pid of the tunnel process.
    """
    logging.info('Starting TLS tunnel: "%s"', ' '.join(command))
    proc = subprocess.Popen(command, preexec_fn=os.setsid, close_fds=True)
    tunnel_pid = proc.pid

    if not is_pid_running(tunnel_pid):
        fatal_error('Failed to initialize TLS tunnel for %s' % state_file, 'Failed to start TLS tunnel.')

    logging.info('Started TLS tunnel, pid: %d', tunnel_pid)
    child_procs.append(proc)
    return tunnel_pid


def clean_up_mount_state(state_file_dir, state_file, pid, is_running, mount_state_dir=None):
    """
    Stop a mount's TLS tunnel (if running) and, once it has exited, delete the mount's state.

    When ``is_running`` is set, SIGTERM goes to the tunnel's whole process group. If the pid is still alive
    afterwards, nothing is removed and a later pass retries. Otherwise the following are deleted, in order:
    every path listed under 'files' in the state file (already-missing files are ignored), the state file itself,
    and finally the optional ``mount_state_dir`` under ``state_file_dir``.
    """
    if is_running:
        pgid = os.getpgid(pid)
        logging.info('Terminating running TLS tunnel - PID: %d, group ID: %s', pid, pgid)
        os.killpg(pgid, SIGTERM)

    if is_pid_running(pid):
        logging.info('TLS tunnel: %d is still running, will retry termination', pid)
        return

    if pid:
        logging.info('TLS tunnel: %d is no longer running, cleaning up state', pid)
    else:
        logging.info('TLS tunnel has been killed, cleaning up state')

    state_path = os.path.join(state_file_dir, state_file)
    with open(state_path) as state_fh:
        state = json.load(state_fh)

    for tracked in state.get('files', list()):
        logging.debug('Deleting %s', tracked)
        try:
            os.remove(tracked)
            logging.debug('Deleted %s', tracked)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    os.remove(state_path)

    if mount_state_dir is None:
        return

    mount_dir_path = os.path.join(state_file_dir, mount_state_dir)
    if os.path.isdir(mount_dir_path):
        shutil.rmtree(mount_dir_path)
    else:
        logging.debug('Attempt to remove mount state directory %s failed. Directory is not present.',
                      mount_dir_path)

def rewrite_state_file(state, state_file_dir, state_file):
    """
    Replace ``state_file`` in ``state_file_dir`` with ``state`` serialized as JSON.

    The JSON is first written to a '~'-prefixed sibling and then renamed over the target. The file is therefore
    swapped in a single rename (atomic on POSIX) rather than rewritten in place. The '~' prefix also keeps
    get_state_files, which only picks up names starting with 'fs-', from treating the scratch file as a state file.
    """
    target = os.path.join(state_file_dir, state_file)
    scratch = os.path.join(state_file_dir, '~' + state_file)
    with open(scratch, 'w') as out:
        json.dump(state, out)
    os.rename(scratch, target)


def mark_as_unmounted(state, state_file_dir, state_file, current_time):
    """
    Stamp ``current_time`` as the mount's 'unmount_time' and persist it.

    ``state`` is mutated in place, written back to the state file, and returned.
    """
    logging.debug('Marking %s as unmounted at %d', state_file, current_time)
    state.update(unmount_time=current_time)
    rewrite_state_file(state, state_file_dir, state_file)
    return state


def restart_tls_tunnel(child_procs, state, state_file_dir, state_file):
    """
    Relaunch a mount's TLS tunnel from the command stored in ``state['cmd']`` and persist the new pid.

    If the state references a certificate file that no longer exists, an error is logged and nothing is started.
    """
    if 'certificate' in state:
        cert_path = state['certificate']
        if not os.path.exists(cert_path):
            logging.error('Cannot restart stunnel because self-signed certificate at %s is missing', cert_path)
            return

    pid = start_tls_tunnel(child_procs, state_file, state['cmd'])
    state['pid'] = pid

    logging.debug('Rewriting %s with new pid: %d', state_file, pid)
    rewrite_state_file(state, state_file_dir, state_file)


def check_efs_mounts(config, child_procs, unmount_grace_period_sec, state_file_dir=STATE_FILE_DIR):
    """
    Run one watchdog pass over every state file in ``state_file_dir``.

    For each state file:
      * if it already has an 'unmount_time', the mount's state is cleaned up once the grace period has elapsed;
      * else if no matching local NFS mount exists, it is stamped with the current time as its unmount_time;
      * otherwise its certificate (if any) is checked, and a TLS tunnel that is not running is restarted.
    State files that are not valid JSON are logged and skipped.
    """
    nfs_mounts = get_current_local_nfs_mounts()
    logging.debug('Current local NFS mounts: %s', list(nfs_mounts.values()))

    state_files = get_state_files(state_file_dir)
    logging.debug('Current state files in "%s": %s', state_file_dir, list(state_files.values()))

    for mount_key, state_file in state_files.items():
        path = os.path.join(state_file_dir, state_file)
        with open(path) as fh:
            try:
                state = json.load(fh)
            except ValueError:
                logging.exception('Unable to parse json in %s', path)
                continue

        try:
            is_running = is_pid_running(state['pid'])
        except KeyError:
            logging.debug('Did not find PID in state file. Assuming stunnel is not running')
            is_running = False

        now = time.time()
        if 'unmount_time' in state:
            if state['unmount_time'] + unmount_grace_period_sec < now:
                logging.info('Unmount grace period expired for %s', state_file)
                clean_up_mount_state(state_file_dir, state_file, state.get('pid'), is_running,
                                     state.get('mountStateDir'))
            continue

        if mount_key not in nfs_mounts:
            logging.info('No mount found for "%s"', state_file)
            mark_as_unmounted(state, state_file_dir, state_file, now)
            continue

        if 'certificate' in state:
            check_certificate(config, state, state_file_dir, state_file)

        if is_running:
            logging.debug('TLS tunnel for %s is running', state_file)
        else:
            logging.warning('TLS tunnel for %s is not running', state_file)
            restart_tls_tunnel(child_procs, state, state_file_dir, state_file)


def check_child_procs(child_procs):
    """
    Poll each tunnel child process, and drop the ones that have exited from ``child_procs`` (mutated in place).

    The loop iterates over a snapshot of the list. Removing items from a list while iterating it directly skips the
    element after each removed one, so that process would not be polled until a later pass.
    """
    for proc in list(child_procs):
        proc.poll()
        if proc.returncode is not None:
            logging.warning('Child TLS tunnel process %d has exited, returncode=%d', proc.pid, proc.returncode)
            child_procs.remove(proc)


def parse_arguments(args=None):
    """
    Handle the supported command-line flags.

    -h/--help prints usage and --version prints the version; each exits with status 0. Other arguments are ignored.
    ``args`` defaults to sys.argv, with the program name first.
    """
    argv = sys.argv if args is None else args
    flags = argv[1:]

    if '-h' in flags or '--help' in flags:
        sys.stdout.write('Usage: %s [--version] [-h|--help]\n' % argv[0])
        sys.exit(0)

    if '--version' in flags:
        sys.stdout.write('%s Version: %s\n' % (argv[0], VERSION))
        sys.exit(0)


def assert_root():
    """Exit with status 1, after a message on stderr, unless the effective user ID is 0 (root)."""
    if os.geteuid() == 0:
        return
    sys.stderr.write('only root can run amazon-efs-mount-watchdog\n')
    sys.exit(1)


def read_config(config_file=CONFIG_FILE):
    """
    Parse ``config_file`` and return the parser.

    On Python 2 the name ConfigParser is the module, so its SafeConfigParser class is used. On Python 3 the name is
    bound to configparser.ConfigParser, which has no SafeConfigParser attribute, so that class is used directly.
    """
    try:
        parser_cls = ConfigParser.SafeConfigParser
    except AttributeError:
        parser_cls = ConfigParser
    parser = parser_cls()
    parser.read(config_file)
    return parser


def check_certificate(config, state, state_file_dir, state_file, base_path=STATE_FILE_DIR):
    """
    Recreate a mount's self-signed certificate if it is missing or older than the configured renewal interval.

    Age is measured from the 'certificateCreationTime' recorded in ``state``. Nothing is done if the state's access
    point ID is malformed. After a successful recreation, the new creation time is written back to the state file
    and SIGHUP is sent to the tunnel's process group.
    """
    created_at = datetime.strptime(state['certificateCreationTime'], CERT_DATETIME_FORMAT)
    cert_present = os.path.isfile(state['certificate'])
    renewal_secs = get_certificate_renewal_interval_mins(config) * 60
    # creation instead of NOT_BEFORE datetime is used for refresh of cert because NOT_BEFORE derives from creation datetime
    cert_is_stale = (get_utc_now() - created_at).total_seconds() > renewal_secs

    if cert_present and not cert_is_stale:
        return

    ap_id = state.get('accessPoint')
    if ap_id and not AP_ID_RE.match(ap_id):
        logging.error('Access Point ID "%s" has been changed in the state file to a malformed format', ap_id)
        return

    if cert_present:
        logging.debug('Refreshing self-signed certificate (at %s)', state['certificate'])
    else:
        logging.debug('Certificate (at %s) is missing. Recreating self-signed certificate', state['certificate'])

    new_creation_time = recreate_certificate(config, state['mountStateDir'], state['commonName'], state['fsId'],
                                             state.get('awsCredentialsMethod'), ap_id, state['region'],
                                             base_path=base_path)
    if not new_creation_time:
        return

    state['certificateCreationTime'] = new_creation_time
    rewrite_state_file(state, state_file_dir, state_file)

    # send SIGHUP to force a reload of the configuration file to trigger the stunnel process to notice the new certificate
    tunnel_pid = state.get('pid')
    if not is_pid_running(tunnel_pid):
        logging.warning('TLS tunnel is not running for %s', state_file)
        return

    group_id = os.getpgid(tunnel_pid)
    logging.info('SIGHUP signal to stunnel. PID: %d, group ID: %s', tunnel_pid, group_id)
    os.killpg(group_id, SIGHUP)


def create_required_directory(config, directory):
    """
    Create ``directory`` (and any missing parents) if it does not already exist.

    The permission bits come from the optional octal ``state_file_dir_mode`` option in the CONFIG_SECTION section of
    ``config``. If the option is absent or unparseable, 0o750 is used. A config that lacks the section entirely is
    treated like one that lacks the option; previously this raised NoSectionError. An existing directory is not an
    error.

    :raises OSError: if the directory cannot be created, or the path exists but is not a directory.
    """
    mode = 0o750
    try:
        mode_str = config.get(CONFIG_SECTION, 'state_file_dir_mode')
    except (NoOptionError, NoSectionError):
        mode_str = None

    if mode_str is not None:
        try:
            mode = int(mode_str, 8)
        except ValueError:
            logging.warning('Bad state_file_dir_mode "%s" in config file "%s"', mode_str, CONFIG_FILE)

    try:
        os.makedirs(directory, mode)
        logging.debug('Expected %s not found, recreating asset', directory)
    except OSError as e:
        if errno.EEXIST != e.errno or not os.path.isdir(directory):
            raise


def get_client_info(config):
    """
    Collect the client info passed to create_ca_conf.

    'source' is the ``source`` option from the CLIENT_INFO_SECTION section, used only when it is non-empty and at
    most CLIENT_SOURCE_STR_LEN_LIMIT characters long; otherwise it is DEFAULT_UNKNOWN_VALUE. 'efs_utils_version'
    is VERSION.
    """
    source = None
    if config.has_option(CLIENT_INFO_SECTION, 'source'):
        candidate = config.get(CLIENT_INFO_SECTION, 'source')
        if 0 < len(candidate) <= CLIENT_SOURCE_STR_LEN_LIMIT:
            source = candidate

    return {
        'source': source or DEFAULT_UNKNOWN_VALUE,
        'efs_utils_version': VERSION,
    }


def recreate_certificate(config, mount_name, common_name, fs_id, credentials_source, ap_id, region,
                         base_path=STATE_FILE_DIR):
    """
    Regenerate the self-signed client certificate for one mount under its state directory.

    Steps: ensure the CA directories/support files and the shared private key exist; derive a public key when a
    credentials source is configured; write the OpenSSL config; create a CSR; self-sign it. The certificate is
    valid from NOT_BEFORE_MINS before to NOT_AFTER_HOURS after the time returned by get_utc_now().

    :return: that time formatted with CERT_DATETIME_FORMAT, or None if create_ca_conf produced no config.
    """
    now = get_utc_now()
    paths = tls_paths_dictionary(mount_name, base_path)
    mount_dir = paths['mount_dir']

    config_path = os.path.join(mount_dir, 'config.conf')
    csr_path = os.path.join(mount_dir, 'request.csr')
    cert_path = os.path.join(mount_dir, 'certificate.pem')

    ca_dirs_check(config, paths['database_dir'], paths['certs_dir'])
    ca_supporting_files_check(paths['index'], paths['index_attr'], paths['serial'], paths['rand'])

    private_key = check_and_create_private_key(base_path)
    if credentials_source:
        create_public_key(private_key, os.path.join(mount_dir, 'publicKey.pem'))

    config_body = create_ca_conf(config_path, common_name, mount_dir, private_key, now, region, fs_id,
                                 credentials_source, ap_id=ap_id, client_info=get_client_info(config))
    if not config_body:
        logging.error('Cannot recreate self-signed certificate')
        return None

    create_certificate_signing_request(config_path, private_key, csr_path)

    start = get_certificate_timestamp(now, minutes=-NOT_BEFORE_MINS)
    end = get_certificate_timestamp(now, hours=NOT_AFTER_HOURS)
    sign_cmd = 'openssl ca -startdate %s -enddate %s -selfsign -batch -notext -config %s -in %s -out %s' % \
        (start, end, config_path, csr_path, cert_path)
    subprocess_call(sign_cmd, 'Failed to create self-signed client-side certificate')
    return now.strftime(CERT_DATETIME_FORMAT)


def get_private_key_path():
    """Return the path of the shared private key.

    Kept as a function (rather than reading PRIVATE_KEY_FILE directly) so unit tests can patch it.
    """
    key_path = PRIVATE_KEY_FILE
    return key_path


def check_and_create_private_key(base_path=STATE_FILE_DIR):
    """Ensure the shared RSA private key exists and return its path.

    Generating RSA keys is slow, so every mount shares one private key. Mounts running in parallel (or the watchdog)
    may all try to create it, so creation is serialized through an exclusive lock file, ``efs-utils-lock``, in
    ``base_path``. The key should already exist from mount time; the watchdog recreates it here if it is missing.

    NOTE(review): the lock is retried every 50 ms with no timeout, so a lock file left behind by a crashed process
    would make this loop forever.

    :param base_path: directory in which the lock file is created
    :return: path of the private key file
    """
    key_path = get_private_key_path()

    @contextmanager
    def exclusive_lock():
        lock_path = os.path.join(base_path, 'efs-utils-lock')
        # O_EXCL makes the open fail with EEXIST while another process holds the lock
        fd = os.open(lock_path, os.O_CREAT | os.O_DSYNC | os.O_EXCL | os.O_RDWR)
        try:
            os.write(fd, ('PID: %s' % os.getpid()).encode('utf-8'))
            yield fd
        finally:
            os.close(fd)
            os.remove(lock_path)

    def run_locked(action):
        while True:
            try:
                with exclusive_lock():
                    return action()
            except OSError as error:
                if error.errno != errno.EEXIST:
                    raise
                logging.info('Failed to take out private key creation lock, sleeping 50 ms')
                time.sleep(0.05)

    def create_key_if_missing():
        # Another process may have created the key while we waited for the lock
        if os.path.isfile(key_path):
            return

        subprocess_call('openssl genpkey -algorithm RSA -out %s -pkeyopt rsa_keygen_bits:3072' % key_path,
                        'Failed to create private key')
        # Owner read-only
        os.chmod(key_path, 0o400)

    run_locked(create_key_if_missing)
    return key_path


def create_certificate_signing_request(config_path, key_path, csr_path):
    """Write a CSR to ``csr_path`` using the OpenSSL config at ``config_path`` and the key at ``key_path``."""
    openssl_req = 'openssl req -new -config {0} -key {1} -out {2}'.format(config_path, key_path, csr_path)
    subprocess_call(openssl_req, 'Failed to create certificate signing request (csr)')


771
def create_ca_conf(config_path, common_name, directory, private_key, date, region, fs_id, credentials_source,
                   ap_id=None, client_info=None):
    """Write the OpenSSL ca/req configuration file for a mount and return its contents.

    The file is regenerated on every call because the embedded SigV4 signature depends on the current time and
    credentials.

    :param config_path: where the configuration is written
    :param directory: mount TLS directory; ``publicKey.pem`` inside it is used to bind the SigV4 signature
    :param credentials_source: IAM credentials lookup method, or a falsy value when IAM auth is not used
    :return: the configuration text, or None (nothing is written) if credentials or the SigV4 section
             cannot be produced
    """
    public_key_path = os.path.join(directory, 'publicKey.pem')

    if credentials_source:
        security_credentials = get_aws_security_credentials(credentials_source)
        if security_credentials is None:
            logging.error('Failed to retrieve AWS security credentials using lookup method: %s', credentials_source)
            return None
    else:
        security_credentials = ''

    ca_extension_body = ca_extension_builder(ap_id, security_credentials, fs_id, client_info)

    if credentials_source:
        efs_client_auth_body = efs_client_auth_builder(public_key_path, security_credentials['AccessKeyId'],
                                                       security_credentials['SecretAccessKey'], date, region, fs_id,
                                                       security_credentials['Token'])
        if not efs_client_auth_body:
            logging.error('Failed to create AWS SigV4 signature section for OpenSSL config. Public Key path: %s',
                          public_key_path)
            return None
    else:
        efs_client_auth_body = ''

    efs_client_info_body = efs_client_info_builder(client_info) if client_info else ''
    full_config_body = CA_CONFIG_BODY % (directory, private_key, common_name, ca_extension_body,
                                         efs_client_auth_body, efs_client_info_body)

    with open(config_path, 'w') as config_file:
        config_file.write(full_config_body)

    return full_config_body


798
def ca_extension_builder(ap_id, security_credentials, fs_id, client_info):
    """Build the ``[ v3_ca ]`` extension section of the OpenSSL config.

    Besides ``subjectKeyIdentifier = hash``, custom OIDs under 1.3.6.1.4.1.4843.7 are emitted:
      .1 access point id            -- only when ``ap_id`` is truthy
      .2 efs_client_auth reference  -- only when ``security_credentials`` is truthy
      .3 file system id             -- always
      .4 efs_client_info reference  -- only when ``client_info`` is truthy
    """
    section = ['[ v3_ca ]', 'subjectKeyIdentifier = hash']

    if ap_id:
        section.append('1.3.6.1.4.1.4843.7.1 = ASN1:UTF8String:' + ap_id)
    if security_credentials:
        section.append('1.3.6.1.4.1.4843.7.2 = ASN1:SEQUENCE:efs_client_auth')

    section.append('1.3.6.1.4.1.4843.7.3 = ASN1:UTF8String:' + fs_id)

    if client_info:
        section.append('1.3.6.1.4.1.4843.7.4 = ASN1:SEQUENCE:efs_client_info')

    return '\n'.join(section)


def efs_client_auth_builder(public_key_path, access_key_id, secret_access_key, date, region, fs_id, session_token=None):
    """Build the ``[ efs_client_auth ]`` section: a SigV4 signature tied to the client's public key.

    :return: the section text, or None when the public key hash cannot be computed
    """
    key_hash = get_public_key_sha1(public_key_path)
    if not key_hash:
        return None

    request = create_canonical_request(key_hash, date, access_key_id, region, fs_id, session_token)
    to_sign = create_string_to_sign(request, date, region)
    signature = calculate_signature(to_sign, date, secret_access_key, region)

    section = [
        '[ efs_client_auth ]',
        'accessKeyId = UTF8String:' + access_key_id,
        'signature = OCTETSTRING:' + signature,
        'sigv4DateTime = UTCTIME:' + date.strftime(CERT_DATETIME_FORMAT),
    ]
    if session_token:
        section.append('sessionToken = EXPLICIT:0,UTF8String:' + session_token)

    return '\n'.join(section)


832
833
834
835
836
837
838
def efs_client_info_builder(client_info):
    """Build the ``[ efs_client_info ]`` section with one UTF8String entry per ``client_info`` item."""
    entries = ['%s = UTF8String: %s' % (key, value) for key, value in client_info.items()]
    return '\n'.join(['[ efs_client_info ]'] + entries)


839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
def create_public_key(private_key, public_key):
    """Write the PEM public key for the RSA private key at ``private_key`` to ``public_key``."""
    openssl_rsa = 'openssl rsa -in {0} -outform PEM -pubout -out {1}'.format(private_key, public_key)
    subprocess_call(openssl_rsa, 'Failed to create public key')


def subprocess_call(cmd, error_message):
    """Run a shell-free command (typically openssl) and return its output.

    The command string is split on whitespace and executed without a shell, so no argument may contain spaces.

    :param cmd: command line to run
    :param error_message: context prefix included in the log entry when the command fails
    :return: ``(stdout, stderr)`` as bytes on success, or None when the command exits with a non-zero status
    """
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
    output, err = process.communicate()
    rc = process.returncode

    if rc != 0:
        # These commands build the TLS key/certificate material, so a failure must be visible at normal log levels
        # rather than hidden at DEBUG.
        logging.error('%s. Command %s failed, rc=%s, stdout="%s", stderr="%s"', error_message, cmd, rc, output, err)
        return None

    return output, err


def ca_dirs_check(config, database_dir, certs_dir):
    """Create the mount's database and certs directories if they are missing.

    Creation is delegated to create_required_directory, in the order database_dir, then certs_dir.
    """
    for required_dir in (database_dir, certs_dir):
        if not os.path.exists(required_dir):
            create_required_directory(config, required_dir)


def ca_supporting_files_check(index_path, index_attr_path, serial_path, rand_path):
    """Recreate any missing openssl ca/req support files with their initial contents.

    Existing files are left untouched. Each recreated file is reported with a warning.
    """
    initial_contents = (
        (index_path, ''),
        (index_attr_path, 'unique_subject = no'),
        (serial_path, '00'),
        (rand_path, ''),
    )

    for path, contents in initial_contents:
        if os.path.isfile(path):
            continue
        with open(path, 'w') as handle:
            handle.write(contents)
        logging.warning('Expected %s not found, recreating file', path)


def tls_paths_dictionary(mount_name, base_path=STATE_FILE_DIR):
    """Return the per-mount TLS file layout rooted at ``base_path/mount_name``.

    Keys: mount_dir, database_dir, certs_dir, plus the openssl ca support files index, index_attr, serial and rand
    (all inside the database directory).
    """
    mount_dir = os.path.join(base_path, mount_name)

    return {
        'mount_dir': mount_dir,
        'database_dir': os.path.join(mount_dir, 'database'),
        'certs_dir': os.path.join(mount_dir, 'certs'),
        'index': os.path.join(mount_dir, 'database/index.txt'),
        'index_attr': os.path.join(mount_dir, 'database/index.txt.attr'),
        'serial': os.path.join(mount_dir, 'database/serial'),
        'rand': os.path.join(mount_dir, 'database/.rand'),
    }


def get_public_key_sha1(public_key):
    """Return the hex SHA-1 of the subjectPublicKey bits of a PEM public key file.

    The BIT STRING tag, length octets and unused-bits octet are excluded from the hash, per
    https://tools.ietf.org/html/rfc5280#section-4.2.1.2

    :param public_key: path to a PEM public key file
    :return: hex digest string, or None when the key cannot be parsed
    """
    # Drop the '-----(BEGIN|END) PUBLIC KEY-----' header/footer lines and decode the DER body
    with open(public_key, 'r') as f:
        lines = f.readlines()[1:-1]

    key = bytearray(base64.b64decode(''.join(lines)))

    # Parse the public key to pull out the actual key material by looking for the key BIT STRING
    # Example:
    #     0:d=0  hl=4 l= 418 cons: SEQUENCE
    #     4:d=1  hl=2 l=  13 cons: SEQUENCE
    #     6:d=2  hl=2 l=   9 prim: OBJECT            :rsaEncryption
    #    17:d=2  hl=2 l=   0 prim: NULL
    #    19:d=1  hl=4 l= 399 prim: BIT STRING
    cmd = 'openssl asn1parse -inform PEM -in %s' % public_key
    result = subprocess_call(cmd, 'Unable to ASN1 parse public key file, %s, correctly' % public_key)
    if result is None:
        # subprocess_call returns None when openssl fails; unpacking it would raise TypeError
        logging.error('Public key file, %s, could not be parsed by openssl', public_key)
        return None
    output, _ = result

    key_line = ''
    for line in output.splitlines():
        decoded = line.decode('utf-8')
        if 'BIT STRING' in decoded:
            key_line = decoded

    if not key_line:
        logging.error('Public key file, %s, is incorrectly formatted', public_key)
        return None

    key_line = key_line.replace(' ', '')

    # DER encoding TLV (Tag, Length, Value)
    # - the first octet (byte) is the tag (type)
    # - the next octets are the length - "definite form":
    #   - short form: a single octet with the high-order bit (8) clear holds the length itself (0-127)
    #   - long form: the first octet has the high-order bit set and its remaining 7 bits give the number of
    #     octets that follow, which encode the length as a big-endian number
    # - the remaining octets are the "value" aka content
    #
    # For a BIT STRING, the first octet of the value is used to signify the number of unused bits that exist in the last
    # content byte. Note that this is explicitly excluded from the SubjectKeyIdentifier hash, per
    # https://tools.ietf.org/html/rfc5280#section-4.2.1.2
    #
    # Example:
    #   0382018f00...<subjectPublicKey>
    #   - 03 - BIT STRING tag
    #   - 82 - long form, 2 length octets to follow
    #   - 018f - length of 399
    #   - 00 - no unused bits in the last content byte
    offset = int(key_line.split(':')[0])
    key = key[offset:]

    first_length_octet = key[1]
    if first_length_octet & 0b10000000:
        num_length_octets = first_length_octet & 0b01111111
    else:
        # Short form: the length lives in this octet, so no additional length octets follow
        num_length_octets = 0

    # Exclude the tag (1), length (1 + num_length_octets), and number of unused bits (1)
    offset = 1 + 1 + num_length_octets + 1
    key = key[offset:]

    sha1 = hashlib.sha1()
    sha1.update(key)

    return sha1.hexdigest()


def create_canonical_request(public_key_hash, date, access_key, region, fs_id, session_token=None):
    """
    Create a Canonical Request - https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html

    The lines are joined with newlines in this order: HTTP method, canonical URI, canonical query string,
    canonical headers (for ``fs_id``), signed headers, and the hex SHA-256 of the request payload.
    """
    formatted_datetime = date.strftime(SIGV4_DATETIME_FORMAT)
    credential = quote_plus(access_key + '/' + get_credential_scope(date, region))
    query_string = create_canonical_query_string(public_key_hash, credential, formatted_datetime, session_token)
    payload_hash = hashlib.sha256(REQUEST_PAYLOAD.encode()).hexdigest()

    return '\n'.join([
        HTTP_REQUEST_METHOD,
        CANONICAL_URI,
        query_string,
        CANONICAL_HEADERS % fs_id,
        SIGNED_HEADERS,
        payload_hash,
    ])


def create_canonical_query_string(public_key_hash, credential, formatted_datetime, session_token=None):
    """Build the SigV4 canonical query string: ``name=value`` pairs sorted by name and joined with ``&``.

    ``credential`` is used as given (the caller has already URL-encoded it).
    """
    params = [
        ('Action', 'Connect'),
        # Public key hash is included in canonical request to tie the signature to a specific key pair to avoid replay attacks
        ('PublicKeyHash', quote_plus(public_key_hash)),
        ('X-Amz-Algorithm', ALGORITHM),
        ('X-Amz-Credential', credential),
        ('X-Amz-Date', quote_plus(formatted_datetime)),
        ('X-Amz-Expires', 86400),
        ('X-Amz-SignedHeaders', SIGNED_HEADERS),
    ]

    if session_token:
        params.append(('X-Amz-Security-Token', quote_plus(session_token)))

    # Joined by hand: urlencode would re-escape the '%' sequences already produced by quote_plus
    return '&'.join('%s=%s' % pair for pair in sorted(params))


def create_string_to_sign(canonical_request, date, region):
    """
    Create a String to Sign - https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html

    The parts are joined with newlines in this order: algorithm, request timestamp, credential scope, and the hex
    SHA-256 of the canonical request.
    """
    request_digest = hashlib.sha256(canonical_request.encode()).hexdigest()

    return '\n'.join([
        ALGORITHM,
        date.strftime(SIGV4_DATETIME_FORMAT),
        get_credential_scope(date, region),
        request_digest,
    ])


def calculate_signature(string_to_sign, date, secret_access_key, region):
    """
    Calculate the Signature - https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html

    The signing key is derived by chaining HMAC-SHA256 over the date, region, service and 'aws4_request', starting
    from 'AWS4' + secret key. The result is the hex HMAC-SHA256 of ``string_to_sign`` under that key.
    """
    def _hmac_sha256(key, message):
        return hmac.new(key, message.encode('utf-8'), hashlib.sha256)

    signing_key = ('AWS4' + secret_access_key).encode('utf-8')
    for scope_part in (date.strftime(DATE_ONLY_FORMAT), region, SERVICE, 'aws4_request'):
        signing_key = _hmac_sha256(signing_key, scope_part).digest()

    return _hmac_sha256(signing_key, string_to_sign).hexdigest()


1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
def get_certificate_renewal_interval_mins(config):
    """Return the self-signed certificate refresh interval in minutes.

    The value comes from ``tls_cert_renewal_interval_min`` in the watchdog config section. It falls back to
    DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN, with a warning, when the option is missing, is not an integer,
    or is below 1.
    """
    try:
        raw_value = config.get(CONFIG_SECTION, 'tls_cert_renewal_interval_min')
    except NoOptionError:
        logging.warning('No tls_cert_renewal_interval_min value in config file "%s". Defaulting to %d minutes.', CONFIG_FILE,
                        DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN)
        return DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN

    try:
        minutes = int(raw_value)
    except ValueError:
        logging.warning('Bad tls_cert_renewal_interval_min value, "%s", in config file "%s". Defaulting to %d minutes.',
                        raw_value, CONFIG_FILE, DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN)
        return DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN

    if minutes > 0:
        return minutes

    logging.warning('tls_cert_renewal_interval_min value in config file "%s" is lower than 1 minute. Defaulting '
                    'to %d minutes.', CONFIG_FILE, DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN)
    return DEFAULT_REFRESH_SELF_SIGNED_CERT_INTERVAL_MIN


1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
def get_credential_scope(date, region):
    """Return the SigV4 credential scope.

    The scope is the date (DATE_ONLY_FORMAT), region, SERVICE and AWS4_REQUEST joined with '/'.
    """
    scope_parts = (date.strftime(DATE_ONLY_FORMAT), region, SERVICE, AWS4_REQUEST)
    return '/'.join(scope_parts)


def get_certificate_timestamp(current_time, **kwargs):
    """Shift ``current_time`` by ``timedelta(**kwargs)`` and format the result with CERT_DATETIME_FORMAT."""
    return (current_time + timedelta(**kwargs)).strftime(CERT_DATETIME_FORMAT)


def get_utc_now():
    """Return the current UTC time as a naive datetime (wrapped so unit tests can patch it)."""
    now = datetime.utcnow()
    return now


1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
def check_process_name(pid):
    """Return the stdout of ``cat /proc/<pid>/cmdline`` as bytes.

    The result is empty when cat prints nothing, e.g. when the pid does not exist.
    """
    cmdline_path = '/proc/{pid}/cmdline'.format(pid=pid)
    reader = subprocess.Popen(['cat', cmdline_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
    stdout, _stderr = reader.communicate()
    return stdout


def clean_up_previous_stunnel_pids(state_file_dir=STATE_FILE_DIR):
    """
    Remove stale stunnel PIDs from the persisted mount state files.

    After an efs-csi-driver restart, upgrade or crash, state files may still record the PID of an stunnel spawned
    by the previous driver's watchdog. For each state file whose recorded PID does not look like a running stunnel,
    the 'pid' entry is dropped and the file is rewritten, so the watchdog will start a new stunnel. Files that are
    not valid JSON, or that have no 'pid', are skipped.
    """
    def _stunnel_is_running(pid):
        cmdline = check_process_name(pid)
        return bool(cmdline) and 'stunnel' in str(cmdline)

    state_files = get_state_files(state_file_dir)
    logging.debug('Persisted state files in "%s": %s', state_file_dir, list(state_files.values()))

    for state_file in state_files.values():
        state_file_path = os.path.join(state_file_dir, state_file)
        with open(state_file_path) as state_fh:
            try:
                state = json.load(state_fh)
            except ValueError:
                logging.exception('Unable to parse json in %s', state_file_path)
                continue

            try:
                pid = state['pid']
            except KeyError:
                logging.debug('No PID found in state file %s', state_file)
                continue

            if _stunnel_is_running(pid):
                logging.debug('PID %s in state file %s is active. Skipping clean up', pid, state_file)
                continue

            del state['pid']
            logging.debug('Cleaning up pid %s in state file %s', pid, state_file)

            rewrite_state_file(state, state_file_dir, state_file)


Max Beckett's avatar
Max Beckett committed
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
def main():
    """Watchdog entry point.

    When the watchdog is enabled in the config, clean up stale stunnel PIDs and then poll mounts and child
    processes forever. The config is re-read on every iteration, but poll_interval_sec and
    unmount_grace_period_sec keep the values read at startup.
    """
    parse_arguments()
    assert_root()

    config = read_config()
    bootstrap_logging(config)

    if not config.getboolean(CONFIG_SECTION, 'enabled'):
        logging.info('amazon-efs-mount-watchdog is not enabled')
        return

    logging.info('amazon-efs-mount-watchdog, version %s, is enabled and started', VERSION)
    poll_interval_sec = config.getint(CONFIG_SECTION, 'poll_interval_sec')
    unmount_grace_period_sec = config.getint(CONFIG_SECTION, 'unmount_grace_period_sec')

    clean_up_previous_stunnel_pids()

    child_procs = []
    while True:
        config = read_config()
        check_efs_mounts(config, child_procs, unmount_grace_period_sec)
        check_child_procs(child_procs)
        time.sleep(poll_interval_sec)

if __name__ == '__main__':
    main()