# Copyright 2015-2017 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import csv
from datetime import datetime, timedelta
import functools
import json
import itertools
import logging
import os
import random
import re
import threading
import time
import six
import sys
from six.moves.urllib import parse as urlparse
from c7n.exceptions import ClientError, PolicyValidationError
from c7n import ipaddress, config
# Try to play nice in the Lambda execution environment,
# where yaml is not required.
try:
    import yaml
except ImportError:  # pragma: no cover
    yaml = None
else:
    try:
        from yaml import CSafeLoader
        SafeLoader = CSafeLoader
    except ImportError:  # pragma: no cover
        try:
            from yaml import SafeLoader
        except ImportError:
            SafeLoader = None
# Module-level logger registered under the 'custodian.utils' name.
log = logging.getLogger('custodian.utils')
class UnicodeWriter:
    """CSV writer that emits UTF-8 encoded cells.

    On Python 3 the wrapped ``csv.writer`` is used directly: its
    ``writerow``/``writerows`` are bound onto the instance and shadow the
    encoding methods defined below, which only run on Python 2.
    """

    def __init__(self, f, dialect=csv.excel, **kwds):
        self.writer = csv.writer(f, dialect=dialect, **kwds)
        if sys.version_info.major == 3:
            # Delegate straight to the csv writer; no manual encoding needed.
            self.writerow = self.writer.writerow
            self.writerows = self.writer.writerows

    def writerow(self, row):
        """Encode every cell of ``row`` as UTF-8 and write it as one record."""
        encoded = [cell.encode("utf-8") for cell in row]
        self.writer.writerow(encoded)

    def writerows(self, rows):
        """Write each row of ``rows`` through :meth:`writerow`."""
        for single_row in rows:
            self.writerow(single_row)
def load_file(path, format=None, vars=None):
    """Read ``path`` and parse its contents as YAML or JSON.

    :param path: file to read.
    :param format: ``'yaml'`` or ``'json'``; when omitted it is inferred from
        the extension (``.json`` selects JSON, anything else YAML).
    :param vars: optional mapping substituted into the raw text with
        ``str.format`` before parsing.
    :returns: the parsed document, or ``None`` for an unrecognised ``format``.
    :raises VarsSubstitutionError: when ``vars`` substitution fails.
    """
    if format is None:
        extension = os.path.splitext(path)[1]
        format = 'json' if extension[1:] == 'json' else 'yaml'

    with open(path) as fh:
        contents = fh.read()

    if vars:
        # NOTE(review): VarsSubstitutionError is neither defined nor imported
        # in the visible part of this module -- confirm it exists elsewhere,
        # otherwise these raises surface as NameError.
        try:
            contents = contents.format(**vars)
        except IndexError:
            raise VarsSubstitutionError(
                'Failed to substitute variable by positional argument.')
        except KeyError as e:
            raise VarsSubstitutionError(
                'Failed to substitute variables.  KeyError on {}'.format(str(e)))

    if format == 'yaml':
        return yaml_load(contents)
    if format == 'json':
        return loads(contents)
    return None
def yaml_load(value):
    """Parse YAML text using the safe loader selected at import time.

    The loader is ``CSafeLoader`` when available, otherwise ``SafeLoader``.

    :raises RuntimeError: if the ``yaml`` package could not be imported.
    """
    if yaml is not None:
        return yaml.load(value, Loader=SafeLoader)
    raise RuntimeError("Yaml not available")
def loads(body):
    """Decode the JSON document ``body`` and return the resulting object."""
    decoded = json.loads(body)
    return decoded
def dumps(data, fh=None, indent=0):
    """Serialize ``data`` to JSON, rendering datetimes via ``DateTimeEncoder``.

    When ``fh`` is truthy the JSON is written to it and ``json.dump``'s result
    (``None``) is returned; otherwise the JSON string is returned. ``indent``
    is forwarded to the json module (the default ``0`` still inserts newlines).
    """
    if not fh:
        return json.dumps(data, cls=DateTimeEncoder, indent=indent)
    return json.dump(data, fh, cls=DateTimeEncoder, indent=indent)
def filter_empty(d):
    """Remove every entry with a falsy value from ``d`` in place and return ``d``."""
    empty_keys = [key for key, val in d.items() if not val]
    for key in empty_keys:
        del d[key]
    return d
def type_schema(
        type_name, inherits=None, rinherit=None,
        aliases=None, required=None, **props):
    """jsonschema generation helper

    params:
     - type_name: name of the type
     - inherits: list of document fragments that are required via allOf[$ref]
     - rinherit: use another schema as a base for this, basically work around
                 inherits issues with additionalProperties and type enums.
                 The base is deep-copied and never modified.
     - aliases: additional names this type maybe called
     - required: list of required properties, by default 'type' is required.
                 The caller's list is not modified.
     - props: additional key value properties
    """
    type_names = [type_name]
    if aliases:
        type_names.extend(aliases)

    if rinherit:
        s = copy.deepcopy(rinherit)
        s['properties']['type'] = {'enum': type_names}
    else:
        s = {
            'type': 'object',
            'properties': {
                'type': {'enum': type_names}}}

    # Ref based inheritance and additional properties don't mix well.
    # https://stackoverflow.com/questions/22689900/json-schema-allof-with-additionalproperties
    if not inherits:
        s['additionalProperties'] = False

    s['properties'].update(props)

    if not required:
        required = []
    if isinstance(required, list):
        # Work on a copy: appending to the caller's list would leak 'type'
        # into it and duplicate the entry on every later call that reuses it.
        required = list(required)
        if 'type' not in required:
            required.append('type')
    s['required'] = required

    if inherits:
        extended = s
        s = {'allOf': [{'$ref': i} for i in inherits]}
        s['allOf'].append(extended)
    return s
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``datetime`` values via ``isoformat()``."""

    def default(self, obj):
        """Encode datetimes as ISO strings; defer anything else to the base class.

        The base ``JSONEncoder.default`` raises ``TypeError`` for unsupported
        objects.
        """
        if not isinstance(obj, datetime):
            return json.JSONEncoder.default(self, obj)
        return obj.isoformat()
def group_by(resources, key):
    """Return a mapping of key value to resources with the corresponding value.

    Key may be specified as dotted form for nested dictionary lookup, e.g.
    ``'State.Name'``. Resources whose path does not fully resolve (a missing
    key, or a non-dict value part way along the path) are grouped under
    ``None``. The resolved value is used as a dict key, so it must be hashable.
    """
    resource_map = {}
    parts = key.split('.')
    for r in resources:
        v = r
        for k in parts:
            if not isinstance(v, dict):
                # The path stopped resolving before its last segment; don't
                # group by whatever intermediate value we happened to reach.
                v = None
                break
            v = v.get(k)
        resource_map.setdefault(v, []).append(r)
    return resource_map
def chunks(iterable, size=50):
    """Yield consecutive lists of ``size`` items taken from ``iterable``.

    The last list carries whatever is left over and may be shorter; nothing
    is yielded for an empty iterable.
    """
    pending = []
    for element in iterable:
        pending.append(element)
        if not len(pending) % size:
            yield pending
            pending = []
    if pending:
        yield pending
def camelResource(obj):
    """Some sources from apis return lowerCased where as describe calls
    always return TitleCase, this function turns the former to the later.

    Dict keys are rewritten in place (first character upper-cased), recursing
    into nested dicts and into dicts held in lists. Non-dict inputs are
    returned unchanged. Keys must be strings; if two keys map to the same
    TitleCase form, the one processed later wins.
    """
    if not isinstance(obj, dict):
        return obj
    for k in list(obj.keys()):
        v = obj.pop(k)
        # Slice instead of indexing so an empty-string key doesn't raise.
        obj["%s%s" % (k[:1].upper(), k[1:])] = v
        if isinstance(v, dict):
            camelResource(v)
        elif isinstance(v, list):
            for item in v:
                camelResource(item)
    return obj
def get_account_id_from_sts(session):
    """Return the ``Account`` field of STS ``get_caller_identity`` (or None)."""
    sts = session.client('sts')
    identity = sts.get_caller_identity()
    return identity.get('Account')
def get_account_alias_from_sts(session):
    """Return the first account alias, or ``''`` when there is none.

    Despite the name, this calls IAM ``list_account_aliases``, not STS.
    An empty first alias also yields ``''``.
    """
    iam = session.client('iam')
    alias_list = iam.list_account_aliases().get('AccountAliases', ())
    if alias_list and alias_list[0]:
        return alias_list[0]
    return ''
def query_instances(session, client=None, **query):
    """Return a flat list of EC2 instances matching ``query``.

    Pages through ``describe_instances`` (keyword ``query`` arguments are
    passed to ``paginate``) and flattens every reservation's ``Instances``.
    ``client`` overrides the ``session.client('ec2')`` default.
    """
    ec2 = session.client('ec2') if client is None else client
    pages = ec2.get_paginator('describe_instances').paginate(**query)
    return [
        instance
        for page in pages
        for reservation in page['Reservations']
        for instance in reservation['Instances']]
# Per-thread cache: one attribute per region name, each holding
# {'session': ..., 'time': ...}.
CONN_CACHE = threading.local()


def local_session(factory):
    """Return a thread-local session built by ``factory``, reused for up to 45m.

    Sessions are cached per thread and keyed by ``factory.region``
    (``'global'`` when the factory has no ``region`` attribute).
    """
    region_key = getattr(factory, 'region', 'global')
    entry = getattr(CONN_CACHE, region_key, {})
    cached_session = entry.get('session')
    created = entry.get('time')
    now = time.time()
    if cached_session is None or created + (60 * 45) <= now:
        cached_session = factory()
        setattr(CONN_CACHE, region_key, {'session': cached_session, 'time': now})
    return cached_session


def reset_session_cache():
    """Drop every cached session for the current thread."""
    public_attrs = [name for name in dir(CONN_CACHE) if not name.startswith('_')]
    for name in public_attrs:
        setattr(CONN_CACHE, name, {})
def annotation(i, k):
    """Return the annotation stored under ``k`` in ``i``, or ``()`` if absent."""
    missing = ()
    return i.get(k, missing)
def set_annotation(i, k, v):
    """Append ``v`` (a value or list of values) to the list annotation ``k`` on ``i``.

    >>> x = {}
    >>> set_annotation(x, 'marker', 'a')
    >>> annotation(x, 'marker')
    ['a']

    A list ``v`` is copied, so later appends never mutate the caller's list.
    If ``k`` already holds a non-list value, it is left untouched and ``v``
    is dropped.

    :raises ValueError: if ``i`` is not a dict.
    """
    if not isinstance(i, dict):
        raise ValueError("Can only annotate dictionaries")
    # Copy so the stored annotation never aliases a list owned by the caller.
    values = list(v) if isinstance(v, list) else [v]
    if k not in i:
        i[k] = values
        return
    existing = i[k]
    if isinstance(existing, list):
        existing.extend(values)
def parse_s3(s3_path):
    """Split an ``s3://bucket/prefix`` URL into its parts.

    :returns: ``(s3_path, bucket, key_prefix)`` where ``s3_path`` has trailing
        slashes removed and ``key_prefix`` keeps its leading ``/`` (``''``
        when there is no prefix).
    :raises ValueError: if ``s3_path`` does not start with ``s3://``.
    """
    if not s3_path.startswith('s3://'):
        raise ValueError("invalid s3 path")
    ridx = s3_path.find('/', 5)
    bucket = s3_path[5:] if ridx == -1 else s3_path[5:ridx]
    s3_path = s3_path.rstrip('/')
    # Re-locate the separator on the stripped path: for 's3://bucket/' it no
    # longer exists, and slicing from -1 would yield the bucket's last char.
    kidx = s3_path.find('/', 5)
    key_prefix = '' if kidx == -1 else s3_path[kidx:]
    return s3_path, bucket, key_prefix
REGION_PARTITION_MAP = {
    'us-gov-east-1': 'aws-us-gov',
    'us-gov-west-1': 'aws-us-gov',
    'cn-north-1': 'aws-cn',
    'cn-northwest-1': 'aws-cn'
}

# Region-name prefixes that identify non-commercial partitions; consulted for
# regions not listed explicitly in REGION_PARTITION_MAP.
_PARTITION_PREFIXES = (
    ('us-gov-', 'aws-us-gov'),
    ('cn-', 'aws-cn'),
)


def generate_arn(
        service, resource, partition='aws',
        region=None, account_id=None, resource_type=None, separator='/'):
    """Generate an Amazon Resource Name.

    See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html.

    A GovCloud or China ``region`` overrides ``partition``: first via
    REGION_PARTITION_MAP, then by region-name prefix. S3 ARNs always omit the
    region. With ``resource_type`` the tail is
    ``<resource_type><separator><resource>``; otherwise it is just ``resource``.
    """
    if region:
        if region in REGION_PARTITION_MAP:
            partition = REGION_PARTITION_MAP[region]
        else:
            for prefix, region_partition in _PARTITION_PREFIXES:
                if region.startswith(prefix):
                    partition = region_partition
                    break
    if service == 's3':
        region = ''
    arn = 'arn:%s:%s:%s:%s:' % (
        partition, service, region if region else '', account_id if account_id else '')
    if resource_type:
        arn = arn + '%s%s%s' % (resource_type, separator, resource)
    else:
        arn = arn + resource
    return arn
def snapshot_identifier(prefix, db_identifier):
    """Return ``<prefix>-<db_identifier>-YYYY-MM-DD-HH-MM`` using local time now.

    Intended as an identifier for a snapshot of a database or cluster.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')
    return '%s-%s-%s' % (prefix, db_identifier, timestamp)
def get_retry(codes=(), max_attempts=8, min_delay=1, log_retries=False):
    """Build a retrying invoker for boto3 API calls that hit transient errors.

    https://www.awsarchitectureblog.com/2015/03/backoff.html
    https://en.wikipedia.org/wiki/Exponential_backoff

    :param codes: A sequence of retryable error codes.
    :param max_attempts: Total number of calls made, including the first.
           The upper bound on any single delay is derived from this value.
    :param min_delay: Starting delay bound in seconds; each later bound
           doubles, and the actual sleep is a jittered value at or below it.
    :param log_retries: Falsy to disable logging; otherwise the logging
           level at which each retry is logged.

    Returns a function ``_retry(func, *args, **kw)`` that calls
    ``func(*args, **kw)`` and retries on ClientErrors whose code is in
    ``codes``, re-raising other errors and the error from the final attempt.
    """
    max_delay = max(min_delay, 2) ** max_attempts

    def _retry(func, *args, **kw):
        delays = backoff_delays(min_delay, max_delay, jitter=True)
        for attempt, delay in enumerate(delays):
            try:
                return func(*args, **kw)
            except ClientError as e:
                error_code = e.response['Error']['Code']
                # Give up on non-retryable errors and on the final attempt.
                if error_code not in codes or attempt == max_attempts - 1:
                    raise
                if log_retries:
                    worker_log.log(
                        log_retries,
                        "retrying %s on error:%s attempt:%d last delay:%0.2f",
                        func, error_code, attempt, delay)
            time.sleep(delay)

    return _retry
def backoff_delays(start, stop, factor=2.0, jitter=False):
    """Yield a geometric sequence ``start, start*factor, ...`` while <= ``stop``.

    With ``jitter`` each yielded value is reduced by a random fraction of
    itself, so it falls in ``(0, value]``.
    """
    delay = start
    while delay <= stop:
        yield (delay - (delay * random.random())) if jitter else delay
        delay *= factor
def parse_cidr(value):
    """Parse ``value`` as an ``IPv4Network`` (if it contains ``/``) or an address.

    Addresses are parsed with ``ipaddress.ip_address``. Returns ``None`` when
    the value cannot be parsed.
    """
    parser = IPv4Network if '/' in value else ipaddress.ip_address
    try:
        return parser(six.text_type(value))
    except (ipaddress.AddressValueError, ValueError):
        return None
class IPv4Network(ipaddress.IPv4Network):
    """IPv4 network whose ``in`` operator also supports network containment.

    ``other in net`` is True for an address inside ``net`` or for a network
    wholly contained by ``net``. ``None`` and networks of a different IP
    version are never contained.
    """

    # Override for net 2 net containment comparison
    def __contains__(self, other):
        if other is None:
            return False
        if isinstance(other, ipaddress._BaseNetwork):
            # supernet_of raises TypeError on mixed IP versions; a network of
            # another version simply can't be inside this one.
            if other.version != self.version:
                return False
            return self.supernet_of(other)
        return super(IPv4Network, self).__contains__(other)
worker_log = logging.getLogger('c7n.worker')


def worker(f):
    """Wrap ``f`` so any uncaught exception is logged before it propagates.

    Tracebacks are lost when crossing concurrent.futures executor boundaries,
    and bulk operations may tolerate partial failures. This wrapper still
    records the full error via ``worker_log.exception`` (so log-based error
    collection can pick it up) and then re-raises it unchanged.
    """
    @functools.wraps(f)
    def _f(*args, **kw):
        try:
            return f(*args, **kw)
        except Exception:
            qualified_name = "%s.%s" % (f.__module__, f.__name__)
            worker_log.exception('Error invoking %s', qualified_name)
            raise

    return _f
# Vendored from botocore.utils to avoid a runtime dependency on botocore
# for other providers; licensed under Apache 2.0.
def set_value_from_jmespath(source, expression, value, is_first=True):
    """Set ``value`` in ``source`` at a dotted path such as ``'a.b.c'``.

    This handles a limited jmespath-like syntax: only dotted lookups, with no
    offsets, wildcards or slices. Missing intermediate keys are created as
    empty dicts. ``is_first`` is accepted for signature compatibility and is
    not used.

    :raises ValueError: if a path segment is empty (e.g. ``'.a'``); the
        message is the remaining sub-expression at that point.
    """
    target = source
    remaining = expression
    while True:
        head, _, tail = remaining.partition('.')
        if not head:
            raise ValueError(remaining)
        if not tail:
            # Down to a single key: set it.
            target[head] = value
            return
        if head not in target:
            # Key not present yet, but more segments follow: create a dict.
            target[head] = {}
        target = target[head]
        remaining = tail
def parse_url_config(url):
    """Parse a URL-style config string into a ``config.Bag``.

    The bag holds ``scheme``, ``netloc`` and ``path``, one entry per query
    parameter (first value only), and the ``url`` itself. A non-empty value
    without ``://`` is treated as a bare scheme (``'foo'`` -> ``'foo://'``).
    """
    if url and '://' not in url:
        url = url + "://"
    conf = config.Bag()
    parsed = urlparse.urlparse(url)
    for attr in ('scheme', 'netloc', 'path'):
        conf[attr] = getattr(parsed, attr)
    for name, values in urlparse.parse_qs(parsed.query).items():
        conf[name] = values[0]
    conf['url'] = url
    return conf
class QueryParser(object):
    """Validate a list of ``{'Name': ..., <value_key>: ...}`` query filters.

    Subclasses configure:
      - ``QuerySchema``: maps allowed filter names to a type that values must
        be instances of, to ``six.string_types``, or to a tuple of allowed
        literal values. Names starting with ``tag:`` are always accepted.
      - ``type_name``: resource label used in error messages.
      - ``multi_value``: when False, each filter carries a single scalar value.
      - ``value_key``: key holding the filter value(s).
    """

    QuerySchema = {}
    type_name = ''
    multi_value = True
    value_key = 'Values'

    @classmethod
    def parse(cls, data):
        """Validate ``data`` and return the list of filter dicts unchanged.

        :raises PolicyValidationError: on any malformed filter, unknown key,
            or disallowed value.
        """
        filters = []
        if not isinstance(data, (tuple, list)):
            raise PolicyValidationError(
                "%s Query invalid format, must be array of dicts %s" % (
                    cls.type_name,
                    data))
        for d in data:
            if not isinstance(d, dict):
                raise PolicyValidationError(
                    "%s Query Filter Invalid %s" % (cls.type_name, data))
            if "Name" not in d or cls.value_key not in d:
                raise PolicyValidationError(
                    "%s Query Filter Invalid: Missing Key or Values in %s" % (
                        cls.type_name, data))

            key = d['Name']
            values = d[cls.value_key]

            if not cls.multi_value and isinstance(values, list):
                # The format string previously had two placeholders for three
                # arguments, raising TypeError instead of this error.
                raise PolicyValidationError(
                    "%s Query Filter Invalid Key: %s Value: %s Must be single valued" % (
                        cls.type_name, key, values))
            elif not cls.multi_value:
                # Normalise a scalar into a list for the checks below.
                values = [values]

            if key not in cls.QuerySchema and not key.startswith('tag:'):
                raise PolicyValidationError(
                    "%s Query Filter Invalid Key:%s Valid: %s" % (
                        cls.type_name, key, ", ".join(cls.QuerySchema.keys())))

            vtype = cls.QuerySchema.get(key)
            if vtype is None and key.startswith('tag'):
                # Tag filters absent from the schema accept any string.
                vtype = six.string_types

            if not isinstance(values, list):
                raise PolicyValidationError(
                    "%s Query Filter Invalid Values, must be array %s" % (
                        cls.type_name, data,))

            for v in values:
                if isinstance(vtype, tuple) and vtype != six.string_types:
                    # Any tuple other than six.string_types is an enum of
                    # allowed literal values.
                    if v not in vtype:
                        raise PolicyValidationError(
                            "%s Query Filter Invalid Value: %s Valid: %s" % (
                                cls.type_name, v, ", ".join(vtype)))
                elif not isinstance(v, vtype):
                    raise PolicyValidationError(
                        "%s Query Filter Invalid Value Type %s" % (
                            cls.type_name, data,))

            filters.append(d)
        return filters