Source code for c7n_azure.utils

# Copyright 2018 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import datetime
import enum
import hashlib
import isodate
import logging
import re
import time
import uuid

import six
from azure.graphrbac.models import GetObjectsParameters, DirectoryObject
from azure.mgmt.managementgroups import ManagementGroupsAPI
from azure.mgmt.web.models import NameValuePair
from c7n_azure import constants
from concurrent.futures import as_completed
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from netaddr import IPNetwork, IPRange

from c7n.utils import chunks
from c7n.utils import local_session


[docs]class ResourceIdParser(object):
[docs] @staticmethod def get_namespace(resource_id): return parse_resource_id(resource_id).get('namespace')
[docs] @staticmethod def get_subscription_id(resource_id): return parse_resource_id(resource_id).get('subscription')
[docs] @staticmethod def get_resource_group(resource_id): result = parse_resource_id(resource_id).get("resource_group") # parse_resource_id fails to parse resource id for resource groups if result is None: return resource_id.split('/')[4] return result
[docs] @staticmethod def get_resource_type(resource_id): parsed = parse_resource_id(resource_id) # parse_resource_id returns dictionary with "child_type_#" to represent # types sequence. "type" stores root type. child_type_keys = [k for k in parsed.keys() if k.find("child_type_") != -1] types = [parsed.get(k) for k in sorted(child_type_keys)] types.insert(0, parsed.get('type')) return '/'.join(types)
[docs] @staticmethod def get_resource_name(resource_id): return parse_resource_id(resource_id).get('resource_name')
[docs]class StringUtils(object):
[docs] @staticmethod def equal(a, b, case_insensitive=True): if isinstance(a, six.string_types) and isinstance(b, six.string_types): if case_insensitive: return a.strip().lower() == b.strip().lower() else: return a.strip() == b.strip() return False
[docs] @staticmethod def snake_to_camel(string): components = string.split('_') return components[0] + ''.join(x.title() for x in components[1:])
[docs] @staticmethod def naming_hash(val, length=8): if isinstance(val, six.string_types): val = val.encode('utf8') return hashlib.sha256(val).hexdigest().lower()[:length]
[docs]def utcnow(): """The datetime object for the current time in UTC """ return datetime.datetime.utcnow()
[docs]def now(tz=None): """The datetime object for the current time in UTC """ return datetime.datetime.now(tz=tz)
[docs]def azure_name_value_pair(name, value): return NameValuePair(**{'name': name, 'value': value})
send_logger = logging.getLogger('custodian.azure.utils.ServiceClient.send')
[docs]def custodian_azure_send_override(self, request, headers=None, content=None, **kwargs): """ Overrides ServiceClient.send() function to implement retries & log headers """ retries = 0 max_retries = 3 while retries < max_retries: response = self.orig_send(request, headers, content, **kwargs) send_logger.debug(response.status_code) for k, v in response.headers.items(): if k.startswith('x-ms-ratelimit'): send_logger.debug(k + ':' + v) # Retry codes from urllib3/util/retry.py if response.status_code in [413, 429, 503]: retry_after = None for k in response.headers.keys(): if StringUtils.equal('retry-after', k): retry_after = int(response.headers[k]) if retry_after is not None and retry_after < constants.DEFAULT_MAX_RETRY_AFTER: send_logger.warning('Received retriable error code %i. Retry-After: %i' % (response.status_code, retry_after)) time.sleep(retry_after) retries += 1 else: send_logger.error("Received throttling error, retry time is %i" "(retry only if < %i seconds)." % (retry_after or 0, constants.DEFAULT_MAX_RETRY_AFTER)) break else: break return response
[docs]class ThreadHelper: disable_multi_threading = False
[docs] @staticmethod def execute_in_parallel(resources, event, execution_method, executor_factory, log, max_workers=constants.DEFAULT_MAX_THREAD_WORKERS, chunk_size=constants.DEFAULT_CHUNK_SIZE): futures = [] results = [] exceptions = [] if ThreadHelper.disable_multi_threading: try: result = execution_method(resources, event) if result: results.extend(result) except Exception as e: exceptions.append(e) else: with executor_factory(max_workers=max_workers) as w: for resource_set in chunks(resources, chunk_size): futures.append(w.submit(execution_method, resource_set, event)) for f in as_completed(futures): if f.exception(): log.error( "Execution failed with error: %s" % f.exception()) exceptions.append(f.exception()) else: result = f.result() if result: results.extend(result) return results, list(set(exceptions))
[docs]class Math(object):
[docs] @staticmethod def mean(numbers): clean_numbers = [e for e in numbers if e is not None] return float(sum(clean_numbers)) / max(len(clean_numbers), 1)
[docs] @staticmethod def sum(numbers): clean_numbers = [e for e in numbers if e is not None] return float(sum(clean_numbers))
[docs]class GraphHelper(object): log = logging.getLogger('custodian.azure.utils.GraphHelper')
[docs] @staticmethod def get_principal_dictionary(graph_client, object_ids, raise_on_graph_call_error=False): """Retrieves Azure AD Objects for corresponding object ids passed. :param graph_client: A client for Microsoft Graph. :param object_ids: The object ids to retrieve Azure AD objects for. :param raise_on_graph_call_error: A boolean indicate whether an error should be raised if the underlying Microsoft Graph call fails. :return: A dictionary keyed by object id with the Azure AD object as the value. Note: empty Azure AD objects could be returned if not found in the graph. """ if not object_ids: return {} object_params = GetObjectsParameters( include_directory_object_references=True, object_ids=object_ids) principal_dics = {object_id: DirectoryObject() for object_id in object_ids} aad_objects = graph_client.objects.get_objects_by_object_ids(object_params) try: for aad_object in aad_objects: principal_dics[aad_object.object_id] = aad_object except CloudError as e: if e.status_code in [403, 401]: GraphHelper.log.warning( 'Credentials not authorized for access to read from Microsoft Graph. \n ' 'Can not query on principalName, displayName, or aadType. \n') else: GraphHelper.log.error( 'Exception in call to Microsoft Graph. \n ' 'Can not query on principalName, displayName, or aadType. \n' 'Error: {0}'.format(e)) if raise_on_graph_call_error: raise return principal_dics
[docs] @staticmethod def get_principal_name(graph_object): """Attempts to resolve a principal name. :param graph_object: the Azure AD Graph Object :return: The resolved value or an empty string if unsuccessful. """ if hasattr(graph_object, 'user_principal_name'): return graph_object.user_principal_name elif hasattr(graph_object, 'service_principal_names'): return graph_object.service_principal_names[0] elif hasattr(graph_object, 'display_name'): return graph_object.display_name return ''
[docs]class PortsRangeHelper(object): PortsRange = collections.namedtuple('PortsRange', 'start end') @staticmethod def _get_port_range(range_str): """ Given a string with a port or port range: '80', '80-120' Returns tuple with range start and end ports: (80, 80), (80, 120) """ if range_str == '*': return PortsRangeHelper.PortsRange(start=0, end=65535) s = range_str.split('-') if len(s) == 2: return PortsRangeHelper.PortsRange(start=int(s[0]), end=int(s[1])) return PortsRangeHelper.PortsRange(start=int(s[0]), end=int(s[0])) @staticmethod def _get_string_port_ranges(ports): """ Extracts ports ranges from the string Returns an array of PortsRange tuples """ return [PortsRangeHelper._get_port_range(r) for r in ports.split(',') if r != ''] @staticmethod def _get_rule_port_ranges(rule): """ Extracts ports ranges from the NSG rule object Returns an array of PortsRange tuples """ properties = rule['properties'] if 'destinationPortRange' in properties: return [PortsRangeHelper._get_port_range(properties['destinationPortRange'])] else: return [PortsRangeHelper._get_port_range(r) for r in properties['destinationPortRanges']] @staticmethod def _port_ranges_to_set(ranges): """ Converts array of port ranges to the set of integers Example: [(10-12), (20,20)] -> {10, 11, 12, 20} """ return set([i for r in ranges for i in range(r.start, r.end + 1)])
[docs] @staticmethod def validate_ports_string(ports): """ Validate that provided string has proper port numbers: 1. port number < 65535 2. range start < range end """ pattern = re.compile('^\\d+(-\\d+)?(,\\d+(-\\d+)?)*$') if pattern.match(ports) is None: return False ranges = PortsRangeHelper._get_string_port_ranges(ports) for r in ranges: if r.start > r.end or r.start > 65535 or r.end > 65535: return False return True
[docs] @staticmethod def get_ports_set_from_string(ports): """ Convert ports range string to the set of integers Example: "10-12, 20" -> {10, 11, 12, 20} """ ranges = PortsRangeHelper._get_string_port_ranges(ports) return PortsRangeHelper._port_ranges_to_set(ranges)
[docs] @staticmethod def get_ports_set_from_rule(rule): """ Extract port ranges from NSG rule and convert it to the set of integers """ ranges = PortsRangeHelper._get_rule_port_ranges(rule) return PortsRangeHelper._port_ranges_to_set(ranges)
[docs] @staticmethod def get_ports_strings_from_list(data): """ Transform a list of port numbers to the list of strings with port ranges Example: [10, 12, 13, 14, 15] -> ['10', '12-15'] """ if len(data) == 0: return [] # Transform diff_ports list to the ranges list first = 0 result = [] for it in range(1, len(data)): if data[first] == data[it] - (it - first): continue result.append(PortsRangeHelper.PortsRange(start=data[first], end=data[it - 1])) first = it # Update tuples with strings, representing ranges result.append(PortsRangeHelper.PortsRange(start=data[first], end=data[-1])) result = [str(x.start) if x.start == x.end else "%i-%i" % (x.start, x.end) for x in result] return result
[docs] @staticmethod def build_ports_dict(nsg, direction_key, ip_protocol): """ Build entire ports array filled with True (Allow), False (Deny) and None(default - Deny) based on the provided Network Security Group object, direction and protocol. """ rules = nsg['properties']['securityRules'] rules = sorted(rules, key=lambda k: k['properties']['priority']) ports = {} for rule in rules: # Skip rules with different direction if not StringUtils.equal(direction_key, rule['properties']['direction']): continue # Check the protocol: possible values are 'TCP', 'UDP', '*' (both) # Skip only if rule and ip_protocol are 'TCP'/'UDP' pair. protocol = rule['properties']['protocol'] if not StringUtils.equal(protocol, "*") and \ not StringUtils.equal(ip_protocol, "*") and \ not StringUtils.equal(protocol, ip_protocol): continue IsAllowed = StringUtils.equal(rule['properties']['access'], 'allow') ports_set = PortsRangeHelper.get_ports_set_from_rule(rule) for p in ports_set: if p not in ports: ports[p] = IsAllowed return ports
[docs]class IpRangeHelper(object):
[docs] @staticmethod def parse_ip_ranges(data, key): ''' Parses IP range or CIDR mask. :param data: Dictionary where to look for the value. :param key: Key for the value to be parsed. :return: Set of IP ranges and networks. ''' if key not in data: return None ranges = [[s.strip() for s in r.split('-')] for r in data[key]] result = set() for r in ranges: if len(r) > 2: raise Exception('Invalid range. Use x.x.x.x-y.y.y.y or x.x.x.x or x.x.x.x/y.') result.add(IPRange(*r) if len(r) == 2 else IPNetwork(r[0])) return result
[docs]class AppInsightsHelper(object): log = logging.getLogger('custodian.azure.utils.AppInsightsHelper')
[docs] @staticmethod def get_instrumentation_key(url): data = url.split('//')[1] try: uuid.UUID(data) except ValueError: values = data.split('/') if len(values) != 2: AppInsightsHelper.log.warning("Bad format: '%s'" % url) return AppInsightsHelper._get_instrumentation_key(values[0], values[1]) return data
@staticmethod def _get_instrumentation_key(resource_group_name, resource_name): from .session import Session s = local_session(Session) client = s.client('azure.mgmt.applicationinsights.ApplicationInsightsManagementClient') try: insights = client.components.get(resource_group_name, resource_name) return insights.instrumentation_key except Exception: AppInsightsHelper.log.warning("Failed to retrieve App Insights instrumentation key." "Resource Group name: %s, App Insights name: %s" % (resource_group_name, resource_name)) return ''
[docs]class ManagedGroupHelper(object):
[docs] @staticmethod def get_subscriptions_list(managed_resource_group, credentials): client = ManagementGroupsAPI(credentials) entities = client.entities.list(filter='name eq \'%s\'' % managed_resource_group) return [e.name for e in entities if e.type == '/subscriptions']
[docs]def generate_key_vault_url(name): return constants.TEMPLATE_KEYVAULT_URL.format(name)
[docs]class RetentionPeriod(object): PATTERN = re.compile("^P([1-9][0-9]*)([DWMY])$")
[docs] @enum.unique class Units(enum.Enum): day = ('day', 'D') days = ('days', 'D') week = ('week', 'W') weeks = ('weeks', 'W') month = ('month', 'M') months = ('months', 'M') year = ('year', 'Y') years = ('years', 'Y') def __init__(self, str_value, iso8601_symbol): self.str_value = str_value self.iso8601_symbol = iso8601_symbol def __str__(self): return self.str_value
[docs] @staticmethod def duration_from_period_and_units(period, retention_period_unit): iso8601_str = "P{}{}".format(period, retention_period_unit.iso8601_symbol) duration = isodate.parse_duration(iso8601_str) return duration
[docs] @staticmethod def parse_iso8601_retention_period(iso8601_retention_period): """ A simplified iso8601 duration parser that only accepts one duration designator. """ match = re.match(RetentionPeriod.PATTERN, iso8601_retention_period) if match is None: raise ValueError("Invalid iso8601_retention_period: {}. " "This parser only accepts a single duration designator." .format(iso8601_retention_period)) period = int(match.group(1)) iso8601_symbol = match.group(2) units = next(units for units in RetentionPeriod.Units if units.iso8601_symbol == iso8601_symbol) return period, units