Source code for c7n.utils

# Copyright 2015-2017 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import csv
from datetime import datetime, timedelta
import functools
import json
import itertools
import logging
import os
import random
import re
import threading
import time
import six
import sys

from six.moves.urllib import parse as urlparse

from c7n.exceptions import ClientError, PolicyValidationError
from c7n import ipaddress, config

# Try to play nice in lambda exec environment
# where we don't require yaml
try:
    import yaml
except ImportError:  # pragma: no cover
    yaml = None
else:
    try:
        from yaml import CSafeLoader
        SafeLoader = CSafeLoader
    except ImportError:  # pragma: no cover
        try:
            from yaml import SafeLoader
        except ImportError:
            SafeLoader = None

log = logging.getLogger('custodian.utils')


class UnicodeWriter:
    """utf8 encoding csv writer."""

    def __init__(self, f, dialect=csv.excel, **kwds):
        self.writer = csv.writer(f, dialect=dialect, **kwds)
        if sys.version_info.major == 3:
            self.writerows = self.writer.writerows
            self.writerow = self.writer.writerow

    def writerow(self, row):
        self.writer.writerow([s.encode("utf-8") for s in row])

    def writerows(self, rows):
        for row in rows:
            self.writerow(row)

class VarsSubstitutionError(Exception):
    pass

def load_file(path, format=None, vars=None):
    if format is None:
        format = 'yaml'
        _, ext = os.path.splitext(path)
        if ext[1:] == 'json':
            format = 'json'

    with open(path) as fh:
        contents = fh.read()

        if vars:
            try:
                contents = contents.format(**vars)
            except IndexError:
                msg = 'Failed to substitute variable by positional argument.'
                raise VarsSubstitutionError(msg)
            except KeyError as e:
                msg = 'Failed to substitute variables. KeyError on {}'.format(str(e))
                raise VarsSubstitutionError(msg)

        if format == 'yaml':
            return yaml_load(contents)
        elif format == 'json':
            return loads(contents)

def yaml_load(value):
    if yaml is None:
        raise RuntimeError("Yaml not available")
    return yaml.load(value, Loader=SafeLoader)

def loads(body):
    return json.loads(body)

def dumps(data, fh=None, indent=0):
    if fh:
        return json.dump(data, fh, cls=DateTimeEncoder, indent=indent)
    else:
        return json.dumps(data, cls=DateTimeEncoder, indent=indent)

def format_event(evt):
    return json.dumps(evt, indent=2)

def filter_empty(d):
    for k, v in list(d.items()):
        if not v:
            del d[k]
    return d

def type_schema(
        type_name, inherits=None, rinherit=None,
        aliases=None, required=None, **props):
    """jsonschema generation helper

    params:
     - type_name: name of the type
     - inherits: list of document fragments that are required via anyOf[$ref]
     - rinherit: use another schema as a base for this, basically work around
       inherits issues with additionalProperties and type enums.
     - aliases: additional names this type may be called
     - required: list of required properties, by default 'type' is required
     - props: additional key value properties
    """
    if aliases:
        type_names = [type_name]
        type_names.extend(aliases)
    else:
        type_names = [type_name]

    if rinherit:
        s = copy.deepcopy(rinherit)
        s['properties']['type'] = {'enum': type_names}
    else:
        s = {
            'type': 'object',
            'properties': {
                'type': {'enum': type_names}}}

    # Ref based inheritance and additional properties don't mix well.
    # https://stackoverflow.com/questions/22689900/json-schema-allof-with-additionalproperties
    if not inherits:
        s['additionalProperties'] = False

    s['properties'].update(props)
    if not required:
        required = []
    if isinstance(required, list):
        required.append('type')
    s['required'] = required
    if inherits:
        extended = s
        s = {'allOf': [{'$ref': i} for i in inherits]}
        s['allOf'].append(extended)
    return s

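# Usage sketch (illustrative only, not part of the original module); 'mark'
# and the 'tag' property are example values:
_example_schema = type_schema('mark', tag={'type': 'string'}, required=['tag'])
assert _example_schema['properties']['type'] == {'enum': ['mark']}
assert _example_schema['required'] == ['tag', 'type']
assert _example_schema['additionalProperties'] is False
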
class DateTimeEncoder(json.JSONEncoder):

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return json.JSONEncoder.default(self, obj)

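# Usage sketch (illustrative only, not part of the original module): dumps/loads
# serialize datetimes via the encoder above.
assert loads(dumps({'when': datetime(2020, 1, 1)})) == {
    'when': '2020-01-01T00:00:00'}
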
def group_by(resources, key):
    """Return a mapping of key value to resources with the corresponding value.

    Key may be specified as dotted form for nested dictionary lookup
    """
    resource_map = {}
    parts = key.split('.')
    for r in resources:
        v = r
        for k in parts:
            v = v.get(k)
            if not isinstance(v, dict):
                break
        resource_map.setdefault(v, []).append(r)
    return resource_map

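# Usage sketch (illustrative only, not part of the original module); the key
# may use dotted notation for nested lookups:
_grouped = group_by(
    [{'Id': 'i-1', 'State': {'Name': 'running'}},
     {'Id': 'i-2', 'State': {'Name': 'stopped'}},
     {'Id': 'i-3', 'State': {'Name': 'running'}}],
    'State.Name')
assert sorted(_grouped) == ['running', 'stopped']
assert [r['Id'] for r in _grouped['running']] == ['i-1', 'i-3']
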
def chunks(iterable, size=50):
    """Break an iterable into lists of size"""
    batch = []
    for n in iterable:
        batch.append(n)
        if len(batch) % size == 0:
            yield batch
            batch = []
    if batch:
        yield batch

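# Usage sketch (illustrative only, not part of the original module):
assert list(chunks(range(5), size=2)) == [[0, 1], [2, 3], [4]]
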
def camelResource(obj):
    """Some api sources return lowerCased keys, whereas describe calls
    always return TitleCase; this function turns the former into the latter.
    """
    if not isinstance(obj, dict):
        return obj
    for k in list(obj.keys()):
        v = obj.pop(k)
        obj["%s%s" % (k[0].upper(), k[1:])] = v
        if isinstance(v, dict):
            camelResource(v)
        elif isinstance(v, list):
            list(map(camelResource, v))
    return obj

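# Usage sketch (illustrative only, not part of the original module):
assert camelResource({'instanceId': 'i-1', 'state': {'name': 'running'}}) == {
    'InstanceId': 'i-1', 'State': {'Name': 'running'}}
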
def get_account_id_from_sts(session):
    response = session.client('sts').get_caller_identity()
    return response.get('Account')

def get_account_alias_from_sts(session):
    response = session.client('iam').list_account_aliases()
    aliases = response.get('AccountAliases', ())
    return aliases and aliases[0] or ''

def query_instances(session, client=None, **query):
    """Return a list of ec2 instances for the query.
    """
    if client is None:
        client = session.client('ec2')
    p = client.get_paginator('describe_instances')
    results = p.paginate(**query)
    return list(itertools.chain(
        *[r["Instances"] for r in itertools.chain(
            *[pp['Reservations'] for pp in results])]))

CONN_CACHE = threading.local()
def local_session(factory):
    """Cache a session thread local for up to 45m"""
    factory_region = getattr(factory, 'region', 'global')
    s = getattr(CONN_CACHE, factory_region, {}).get('session')
    t = getattr(CONN_CACHE, factory_region, {}).get('time')
    n = time.time()
    if s is not None and t + (60 * 45) > n:
        return s
    s = factory()
    setattr(CONN_CACHE, factory_region, {'session': s, 'time': n})
    return s

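# Usage sketch (illustrative only, not part of the original module);
# session_factory stands in for a c7n session factory passed to resources:
#
#   session = local_session(session_factory)
#   ec2 = session.client('ec2')
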
def reset_session_cache():
    for k in [k for k in dir(CONN_CACHE) if not k.startswith('_')]:
        setattr(CONN_CACHE, k, {})

def annotation(i, k):
    return i.get(k, ())

def set_annotation(i, k, v):
    """
    >>> x = {}
    >>> set_annotation(x, 'marker', 'a')
    >>> annotation(x, 'marker')
    ['a']
    """
    if not isinstance(i, dict):
        raise ValueError("Can only annotate dictionaries")

    if not isinstance(v, list):
        v = [v]

    if k in i:
        ev = i.get(k)
        if isinstance(ev, list):
            ev.extend(v)
    else:
        i[k] = v

def parse_s3(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError("invalid s3 path")
    ridx = s3_path.find('/', 5)
    if ridx == -1:
        ridx = None
    bucket = s3_path[5:ridx]
    s3_path = s3_path.rstrip('/')
    if ridx is None:
        key_prefix = ""
    else:
        key_prefix = s3_path[s3_path.find('/', 5):]
    return s3_path, bucket, key_prefix

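# Usage sketch (illustrative only, not part of the original module):
assert parse_s3('s3://my-bucket/some/prefix') == (
    's3://my-bucket/some/prefix', 'my-bucket', '/some/prefix')
assert parse_s3('s3://my-bucket') == ('s3://my-bucket', 'my-bucket', '')
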
REGION_PARTITION_MAP = {
    'us-gov-east-1': 'aws-us-gov',
    'us-gov-west-1': 'aws-us-gov',
    'cn-north-1': 'aws-cn',
    'cn-northwest-1': 'aws-cn'
}

def generate_arn(
        service, resource, partition='aws',
        region=None, account_id=None, resource_type=None, separator='/'):
    """Generate an Amazon Resource Name.
    See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html.
    """
    if region and region in REGION_PARTITION_MAP:
        partition = REGION_PARTITION_MAP[region]
    if service == 's3':
        region = ''
    arn = 'arn:%s:%s:%s:%s:' % (
        partition, service,
        region if region else '',
        account_id if account_id else '')
    if resource_type:
        arn = arn + '%s%s%s' % (resource_type, separator, resource)
    else:
        arn = arn + resource
    return arn

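# Usage sketch (illustrative only, not part of the original module); the
# account id and region are example values:
assert generate_arn(
    'sqs', 'my-queue', region='us-east-1',
    account_id='123456789012') == 'arn:aws:sqs:us-east-1:123456789012:my-queue'
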
def snapshot_identifier(prefix, db_identifier):
    """Return an identifier for a snapshot of a database or cluster.
    """
    now = datetime.now()
    return '%s-%s-%s' % (prefix, db_identifier, now.strftime('%Y-%m-%d-%H-%M'))

def get_retry(codes=(), max_attempts=8, min_delay=1, log_retries=False):
    """Decorator for retrying boto3 api calls on transient errors.

    https://www.awsarchitectureblog.com/2015/03/backoff.html
    https://en.wikipedia.org/wiki/Exponential_backoff

    :param codes: A sequence of retryable error codes.
    :param max_attempts: The max number of retries, by default the delay
           time is proportional to the max number of attempts.
    :param log_retries: Whether we should log retries; if specified, it gives
           the log level at which retries should be logged.
    :param _max_delay: The maximum delay for any retry interval *note* this
           parameter is only exposed for unit testing, as it's derived from
           the number of attempts.

    Returns a function for invoking aws client calls that
    retries on retryable error codes.
    """
    max_delay = max(min_delay, 2) ** max_attempts

    def _retry(func, *args, **kw):
        for idx, delay in enumerate(
                backoff_delays(min_delay, max_delay, jitter=True)):
            try:
                return func(*args, **kw)
            except ClientError as e:
                if e.response['Error']['Code'] not in codes:
                    raise
                elif idx == max_attempts - 1:
                    raise
                if log_retries:
                    worker_log.log(
                        log_retries,
                        "retrying %s on error:%s attempt:%d last delay:%0.2f",
                        func, e.response['Error']['Code'], idx, delay)
            time.sleep(delay)
    return _retry

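# Usage sketch (illustrative only, not part of the original module);
# "Throttling" and the wrapped client call are example values:
#
#   retry = get_retry(('Throttling',))
#   retry(client.describe_instances, InstanceIds=['i-123'])
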
def backoff_delays(start, stop, factor=2.0, jitter=False):
    """Geometric backoff sequence w/ jitter
    """
    cur = start
    while cur <= stop:
        if jitter:
            yield cur - (cur * random.random())
        else:
            yield cur
        cur = cur * factor

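# Usage sketch (illustrative only, not part of the original module):
assert list(backoff_delays(1, 16)) == [1, 2.0, 4.0, 8.0, 16.0]
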
def parse_cidr(value):
    """Process cidr ranges."""
    klass = IPv4Network
    if '/' not in value:
        klass = ipaddress.ip_address
    try:
        v = klass(six.text_type(value))
    except (ipaddress.AddressValueError, ValueError):
        v = None
    return v

class IPv4Network(ipaddress.IPv4Network):

    # Override for net 2 net containment comparison
    def __contains__(self, other):
        if other is None:
            return False
        if isinstance(other, ipaddress._BaseNetwork):
            return self.supernet_of(other)
        return super(IPv4Network, self).__contains__(other)

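# Usage sketch (illustrative only, not part of the original module):
assert parse_cidr('10.0.0.0/24').num_addresses == 256
assert parse_cidr('10.0.0.5') in parse_cidr('10.0.0.0/24')
# The IPv4Network subclass above also allows network-in-network containment.
assert parse_cidr('10.0.0.128/25') in parse_cidr('10.0.0.0/24')
assert parse_cidr('not-an-address') is None
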
worker_log = logging.getLogger('c7n.worker')

def worker(f):
    """Generic wrapper to log uncaught exceptions in a function.

    When we cross concurrent.futures executor boundaries we lose our
    traceback information, and when doing bulk operations we may tolerate
    transient failures on a partial subset. However we still want to have
    full accounting of the error in the logs, in a format that our error
    collection (cwl subscription) can still pick up.
    """
    def _f(*args, **kw):
        try:
            return f(*args, **kw)
        except Exception:
            worker_log.exception(
                'Error invoking %s',
                "%s.%s" % (f.__module__, f.__name__))
            raise
    functools.update_wrapper(_f, f)
    return _f

def reformat_schema(model):
    """Reformat schema to be in a more displayable format."""
    if not hasattr(model, 'schema'):
        return "Model '{}' does not have a schema".format(model)

    if 'properties' not in model.schema:
        return "Schema in unexpected format."

    ret = copy.deepcopy(model.schema['properties'])

    if 'type' in ret:
        del(ret['type'])

    for key in model.schema.get('required', []):
        if key in ret:
            ret[key]['required'] = True

    return ret

# From botocore.utils; avoids a runtime dependency on botocore for other providers.
# License: Apache 2.0
def set_value_from_jmespath(source, expression, value, is_first=True):
    # This takes a (limited) jmespath-like expression & can set a value based
    # on it.
    #
    # Limitations:
    # * Only handles dotted lookups
    # * No offsets/wildcards/slices/etc.
    bits = expression.split('.', 1)
    current_key, remainder = bits[0], bits[1] if len(bits) > 1 else ''

    if not current_key:
        raise ValueError(expression)

    if remainder:
        if current_key not in source:
            # We've got something in the expression that's not present in the
            # source (new key). If there's any more bits, we'll set the key
            # with an empty dictionary.
            source[current_key] = {}
        return set_value_from_jmespath(
            source[current_key],
            remainder,
            value,
            is_first=False
        )

    # If we're down to a single key, set it.
    source[current_key] = value

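# Usage sketch (illustrative only, not part of the original module):
_doc = {}
set_value_from_jmespath(_doc, 'Tags.Env', 'dev')
assert _doc == {'Tags': {'Env': 'dev'}}
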
def format_string_values(obj, err_fallback=(IndexError, KeyError), *args, **kwargs):
    """
    Format all string values in an object.
    Return the updated object
    """
    if isinstance(obj, dict):
        new = {}
        for key in obj.keys():
            new[key] = format_string_values(obj[key], *args, **kwargs)
        return new
    elif isinstance(obj, list):
        new = []
        for item in obj:
            new.append(format_string_values(item, *args, **kwargs))
        return new
    elif isinstance(obj, six.string_types):
        try:
            return obj.format(*args, **kwargs)
        except err_fallback:
            return obj
    else:
        return obj

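# Usage sketch (illustrative only, not part of the original module); unknown
# placeholders are left untouched via err_fallback:
_templated = format_string_values(
    {'queue': 'c7n-{region}', 'tags': ['{account_id}', '{not_supplied}']},
    region='us-east-1', account_id='123456789012')
assert _templated == {
    'queue': 'c7n-us-east-1', 'tags': ['123456789012', '{not_supplied}']}
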
def parse_url_config(url):
    if url and '://' not in url:
        url += "://"
    conf = config.Bag()
    parsed = urlparse.urlparse(url)
    for k in ('scheme', 'netloc', 'path'):
        conf[k] = getattr(parsed, k)
    for k, v in urlparse.parse_qs(parsed.query).items():
        conf[k] = v[0]
    conf['url'] = url
    return conf

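# Usage sketch (illustrative only, not part of the original module); the scheme,
# netloc and query string parameters become keys on the returned Bag:
_conf = parse_url_config('sqs://my-queue?region=us-east-1')
assert _conf['scheme'] == 'sqs'
assert _conf['netloc'] == 'my-queue'
assert _conf['region'] == 'us-east-1'
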
class FormatDate(object):
    """a datetime wrapper with extended pyformat syntax"""

    date_increment = re.compile(r'\+[0-9]+[Mdh]')

    def __init__(self, d=None):
        self._d = d

    @classmethod
    def utcnow(cls):
        return cls(datetime.utcnow())

    def __getattr__(self, k):
        return getattr(self._d, k)

    def __format__(self, fmt=None):
        d = self._d
        increments = self.date_increment.findall(fmt)
        for i in increments:
            p = {}
            if i[-1] == 'M':
                p['minutes'] = float(i[1:-1])
            if i[-1] == 'h':
                p['hours'] = float(i[1:-1])
            if i[-1] == 'd':
                p['days'] = float(i[1:-1])
            d = d + timedelta(**p)
        if increments:
            fmt = self.date_increment.sub("", fmt)
        return d.__format__(fmt)

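# Usage sketch (illustrative only, not part of the original module); a "+Nh",
# "+Nd" or "+NM" suffix in the format spec shifts the rendered time forward:
_when = FormatDate(datetime(2020, 1, 1, 12, 0))
assert '{:%Y-%m-%d %H:%M+2h}'.format(_when) == '2020-01-01 14:00'
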
class QueryParser(object):

    QuerySchema = {}
    type_name = ''
    multi_value = True
    value_key = 'Values'

    @classmethod
    def parse(cls, data):
        filters = []
        if not isinstance(data, (tuple, list)):
            raise PolicyValidationError(
                "%s Query invalid format, must be array of dicts %s" % (
                    cls.type_name, data))
        for d in data:
            if not isinstance(d, dict):
                raise PolicyValidationError(
                    "%s Query Filter Invalid %s" % (cls.type_name, data))
            if "Name" not in d or cls.value_key not in d:
                raise PolicyValidationError(
                    "%s Query Filter Invalid: Missing Key or Values in %s" % (
                        cls.type_name, data))

            key = d['Name']
            values = d[cls.value_key]

            if not cls.multi_value and isinstance(values, list):
                raise PolicyValidationError(
                    "%s Query Filter Invalid Key: %s Value: %s Must be single valued" % (
                        cls.type_name, key, values))
            elif not cls.multi_value:
                values = [values]

            if key not in cls.QuerySchema and not key.startswith('tag:'):
                raise PolicyValidationError(
                    "%s Query Filter Invalid Key:%s Valid: %s" % (
                        cls.type_name, key, ", ".join(cls.QuerySchema.keys())))

            vtype = cls.QuerySchema.get(key)
            if vtype is None and key.startswith('tag'):
                vtype = six.string_types

            if not isinstance(values, list):
                raise PolicyValidationError(
                    "%s Query Filter Invalid Values, must be array %s" % (
                        cls.type_name, data,))

            for v in values:
                if isinstance(vtype, tuple) and vtype != six.string_types:
                    if v not in vtype:
                        raise PolicyValidationError(
                            "%s Query Filter Invalid Value: %s Valid: %s" % (
                                cls.type_name, v, ", ".join(vtype)))
                elif not isinstance(v, vtype):
                    raise PolicyValidationError(
                        "%s Query Filter Invalid Value Type %s" % (
                            cls.type_name, data,))

            filters.append(d)

        return filters

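# Usage sketch (illustrative only, not part of the original module); the base
# class accepts tag filters, while subclasses whitelist further keys via
# QuerySchema:
assert QueryParser.parse(
    [{'Name': 'tag:App', 'Values': ['web']}]) == [
        {'Name': 'tag:App', 'Values': ['web']}]
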