Source code for c7n.resources.cloudfront

# Copyright 2016-2017 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import functools
import re

from c7n.actions import BaseAction
from c7n.filters import MetricsFilter, ShieldMetrics, Filter
from c7n.manager import resources
from c7n.query import QueryResourceManager, DescribeSource
from c7n.tags import universal_augment
from c7n.utils import generate_arn, local_session, type_schema, get_retry

from c7n.resources.shield import IsShieldProtected, SetShieldProtection


@resources.register('distribution')
class Distribution(QueryResourceManager):
    """CloudFront web distributions.

    Resources are enumerated with ``list_distributions`` and, for the
    ``describe`` source, augmented by :class:`DescribeDistribution`.
    """

    class resource_type(object):
        # Static metadata consumed by the c7n query framework.
        service = 'cloudfront'
        type = 'distribution'
        enum_spec = ('list_distributions', 'DistributionList.Items', None)
        id = 'Id'
        name = 'DomainName'
        date = 'LastModifiedTime'
        dimension = "DistributionId"
        universal_taggable = True
        filter_name = None
        config_type = "AWS::CloudFront::Distribution"
        # Denotes this resource type exists across regions
        global_resource = True

    def get_arn(self, r):
        """Return the ARN carried on the described resource itself."""
        return r['ARN']

    @property
    def generate_arn(self):
        """Lazily build and cache a partial of ``c7n.utils.generate_arn``.

        Generates a generic arn if the ID is not already in arn format; the
        partial is bound to this model's service/type and uses ``/`` as the
        resource separator.
        """
        if self._generate_arn is not None:
            return self._generate_arn
        model = self.get_model()
        self._generate_arn = functools.partial(
            generate_arn,
            model.service,
            account_id=self.account_id,
            resource_type=model.type,
            separator='/',
        )
        return self._generate_arn

    def get_source(self, source_type):
        """Use the tag-augmenting describe source; defer anything else upward."""
        if source_type != 'describe':
            return super(Distribution, self).get_source(source_type)
        return DescribeDistribution(self)
class DescribeDistribution(DescribeSource):
    """Describe source for web distributions."""

    def augment(self, resources):
        """Delegate resource augmentation to ``c7n.tags.universal_augment``."""
        augmented = universal_augment(self.manager, resources)
        return augmented
@resources.register('streaming-distribution')
class StreamingDistribution(QueryResourceManager):
    """CloudFront streaming distributions.

    Resources are enumerated with ``list_streaming_distributions`` and, for
    the ``describe`` source, augmented by
    :class:`DescribeStreamingDistribution`.
    """

    class resource_type(object):
        # Static metadata consumed by the c7n query framework.
        service = 'cloudfront'
        type = 'streaming-distribution'
        enum_spec = ('list_streaming_distributions',
                     'StreamingDistributionList.Items',
                     None)
        id = 'Id'
        name = 'DomainName'
        date = 'LastModifiedTime'
        dimension = "DistributionId"
        universal_taggable = True
        filter_name = None
        config_type = "AWS::CloudFront::StreamingDistribution"

    def get_arn(self, r):
        """Return the ARN carried on the described resource itself."""
        return r['ARN']

    @property
    def generate_arn(self):
        """Lazily build and cache a partial of ``c7n.utils.generate_arn``.

        Generates a generic arn if the ID is not already in arn format; the
        partial is bound to this model's service/type and uses ``/`` as the
        resource separator.
        """
        if self._generate_arn is not None:
            return self._generate_arn
        model = self.get_model()
        self._generate_arn = functools.partial(
            generate_arn,
            model.service,
            account_id=self.account_id,
            resource_type=model.type,
            separator='/',
        )
        return self._generate_arn

    def get_source(self, source_type):
        """Use the tag-augmenting describe source; defer anything else upward."""
        if source_type != 'describe':
            return super(StreamingDistribution, self).get_source(source_type)
        return DescribeStreamingDistribution(self)
class DescribeStreamingDistribution(DescribeSource):
    """Describe source for streaming distributions."""

    def augment(self, resources):
        """Delegate resource augmentation to ``c7n.tags.universal_augment``."""
        augmented = universal_augment(self.manager, resources)
        return augmented
# Shield integration: within this module these filters and the action are
# registered on web distributions (Distribution) only.
Distribution.filter_registry.register(
    'shield-metrics', ShieldMetrics)
Distribution.filter_registry.register(
    'shield-enabled', IsShieldProtected)
Distribution.action_registry.register(
    'set-shield', SetShieldProtection)
@Distribution.filter_registry.register('metrics')
@StreamingDistribution.filter_registry.register('metrics')
class DistributionMetrics(MetricsFilter):
    """Filter cloudfront distributions based on metric values

    :example:

    .. code-block:: yaml

            policies:
              - name: cloudfront-distribution-errors
                resource: distribution
                filters:
                  - type: metrics
                    name: Requests
                    value: 3
                    op: ge
    """

    def get_dimensions(self, resource):
        """Return metric dimensions: the distribution id plus ``Region=Global``."""
        distribution_dimension = {
            'Name': self.model.dimension,
            'Value': resource[self.model.id],
        }
        region_dimension = {'Name': 'Region', 'Value': 'Global'}
        return [distribution_dimension, region_dimension]
@Distribution.filter_registry.register('waf-enabled')
class IsWafEnabled(Filter):
    """Filter distributions by their WAF web ACL association.

    ``web-acl`` may be a WAF web ACL name or id (names are resolved through
    the ``waf`` resource manager). ``state`` (default false) selects
    distributions that do (true) or do not (false) match.

    - Without ``web-acl``: match on whether any ``WebACLId`` is set.
    - With ``web-acl``: match on whether ``WebACLId`` equals that ACL's id.

    Raises ``ValueError`` if ``web-acl`` does not resolve to a known ACL.
    """
    # useful primarily to use the same name across accounts, else webaclid
    # attribute works as well
    schema = type_schema(
        'waf-enabled', **{
            'web-acl': {'type': 'string'},
            'state': {'type': 'boolean'}})

    permissions = ('waf:ListWebACLs',)

    def process(self, resources, event=None):
        wafs = self.manager.get_resource_manager('waf').resources()
        waf_name_id_map = {w['Name']: w['WebACLId'] for w in wafs}
        known_acl_ids = set(waf_name_id_map.values())

        target_acl = self.data.get('web-acl')
        # Accept either an ACL name (mapped to its id) or a raw ACL id.
        target_acl_id = waf_name_id_map.get(target_acl, target_acl)
        if target_acl_id and target_acl_id not in known_acl_ids:
            raise ValueError("invalid web acl: %s" % (target_acl_id))

        state = bool(self.data.get('state', False))
        results = []
        for r in resources:
            # A missing key is treated like an empty id (no ACL attached).
            current_acl_id = r.get('WebACLId')
            if target_acl_id is None:
                # No specific ACL requested: match on presence of any ACL.
                if bool(current_acl_id) == state:
                    results.append(r)
            elif target_acl_id:
                # Specific ACL requested: match on equality with that ACL.
                if (current_acl_id == target_acl_id) == state:
                    results.append(r)
        return results
@Distribution.filter_registry.register('mismatch-s3-origin')
class MismatchS3Origin(Filter):
    """Check for existence of S3 bucket referenced by Cloudfront,
       and verify whether owner is different from Cloudfront account owner.

    Each matching distribution is annotated with
    ``c7n:mismatched-s3-origin``: the list of referenced bucket names that
    are not present in the account's ``list_buckets`` output.

    :example:

    .. code-block:: yaml

            policies:
              - name: mismatch-s3-origin
                resource: distribution
                filters:
                  - type: mismatch-s3-origin
                    check_custom_origins: true
    """

    s3_prefix = re.compile(r'.*(?=\.s3(-.*)?\.amazonaws.com)')
    s3_suffix = re.compile(r'^([^.]+\.)?s3(-.*)?\.amazonaws.com')

    schema = type_schema(
        'mismatch-s3-origin',
        check_custom_origins={'type': 'boolean'})

    # The S3 ListBuckets API is authorized by the s3:ListAllMyBuckets action.
    permissions = ('s3:ListAllMyBuckets',)

    retry = staticmethod(get_retry(('Throttling',)))

    def is_s3_domain(self, x):
        """Return the bucket name referenced by origin ``x``, or None.

        A ``<bucket>.s3...amazonaws.com`` domain yields the part before
        ``.s3``. For a bare S3 endpoint domain (path-style addressing) the
        bucket is the first segment of ``OriginPath``.
        """
        domain = x['DomainName']
        bucket_match = self.s3_prefix.match(domain)
        if bucket_match:
            return bucket_match.group()
        if self.s3_suffix.match(domain):
            path = x.get('OriginPath') or ''
            if path.startswith('/'):
                path = path[1:]
            # '/bucket/prefix' -> 'bucket'; an empty path names no bucket.
            bucket = path.split('/', 1)[0]
            return bucket or None
        return None

    def _origin_bucket(self, origin, check_custom_origins):
        """Return the bucket an origin points at, or None if not applicable."""
        if 'S3OriginConfig' in origin:
            bucket_match = self.s3_prefix.match(origin['DomainName'])
            return bucket_match.group() if bucket_match else None
        if 'CustomOriginConfig' in origin and check_custom_origins:
            return self.is_s3_domain(origin)
        return None

    def process(self, resources, event=None):
        results = []

        s3_client = local_session(self.manager.session_factory).client(
            's3', region_name=self.manager.config.region)
        listing = self.retry(s3_client.list_buckets)
        buckets = {b['Name'] for b in listing['Buckets']}
        check_custom_origins = self.data.get('check_custom_origins')

        for r in resources:
            mismatched = []
            r['c7n:mismatched-s3-origin'] = mismatched
            for origin in r['Origins'].get('Items', []):
                # Resolved freshly per origin so no value leaks between origins.
                target_bucket = self._origin_bucket(origin, check_custom_origins)
                if target_bucket is None or target_bucket in buckets:
                    continue
                self.log.debug(
                    "Bucket %s not found in distribution %s hosting account.",
                    target_bucket, r['Id'])
                mismatched.append(target_bucket)
            # Report each distribution once, however many origins mismatch.
            if mismatched:
                results.append(r)
        return results
@Distribution.action_registry.register('set-waf')
class SetWaf(BaseAction):
    """Associate a WAF web ACL with distributions.

    ``web-acl`` (required) is a WAF web ACL name or id. Distributions that
    already have a ``WebACLId`` are skipped unless ``force`` is true, and
    distributions already using the target ACL are always skipped. The
    ``state`` key is accepted by the schema but not read by this action.

    Raises ``ValueError`` if ``web-acl`` does not resolve to a known ACL.

    :example:

    .. code-block:: yaml

            policies:
              - name: distribution-set-waf
                resource: distribution
                filters:
                  - type: waf-enabled
                    state: false
                    web-acl: test
                actions:
                  - type: set-waf
                    web-acl: test
    """
    permissions = ('cloudfront:UpdateDistribution',
                   'cloudfront:GetDistributionConfig',
                   'waf:ListWebACLs')
    schema = type_schema(
        'set-waf', required=['web-acl'], **{
            'web-acl': {'type': 'string'},
            'force': {'type': 'boolean'},
            'state': {'type': 'boolean'}})

    retry = staticmethod(get_retry(('Throttling',)))

    def process(self, resources):
        wafs = self.manager.get_resource_manager('waf').resources()
        waf_name_id_map = {w['Name']: w['WebACLId'] for w in wafs}
        target_acl = self.data.get('web-acl')
        # Accept either an ACL name (mapped to its id) or a raw ACL id.
        target_acl_id = waf_name_id_map.get(target_acl, target_acl)

        if target_acl_id not in set(waf_name_id_map.values()):
            raise ValueError("invalid web acl: %s" % (target_acl_id))

        client = local_session(self.manager.session_factory).client(
            'cloudfront')
        force = self.data.get('force', False)

        for r in resources:
            current_acl_id = r.get('WebACLId')
            # Leave existing associations alone unless explicitly forced.
            if current_acl_id and not force:
                continue
            if current_acl_id == target_acl_id:
                continue
            # The update must carry the ETag of the config it is based on.
            result = self.retry(client.get_distribution_config, Id=r['Id'])
            config = result['DistributionConfig']
            config['WebACLId'] = target_acl_id
            self.retry(
                client.update_distribution,
                Id=r['Id'], DistributionConfig=config, IfMatch=result['ETag'])
@Distribution.action_registry.register('disable')
class DistributionDisableAction(BaseAction):
    """Action to disable a Distribution

    Fetches each distribution's config, sets ``Enabled`` to false and
    writes it back. Failures are logged per distribution and do not stop
    the remaining distributions from being processed.

    :example:

    .. code-block:: yaml

            policies:
              - name: distribution-disable
                resource: distribution
                filters:
                  - type: value
                    key: CacheBehaviors.Items[].ViewerProtocolPolicy
                    value: allow-all
                    op: contains
                actions:
                  - type: disable
    """
    schema = type_schema('disable')
    # IAM actions for CloudFront live under the "cloudfront:" service prefix.
    permissions = ("cloudfront:GetDistributionConfig",
                   "cloudfront:UpdateDistribution",)

    def process(self, distributions):
        client = local_session(
            self.manager.session_factory).client(self.manager.get_model().service)
        for d in distributions:
            self.process_distribution(client, d)

    def process_distribution(self, client, distribution):
        """Disable one distribution; log (don't raise) on any failure."""
        try:
            dist_id = distribution[self.manager.get_model().id]
            res = client.get_distribution_config(Id=dist_id)
            config = res['DistributionConfig']
            config['Enabled'] = False
            # The update must carry the ETag of the config it is based on.
            client.update_distribution(
                Id=dist_id,
                IfMatch=res['ETag'],
                DistributionConfig=config,
            )
        except Exception as e:
            # Deliberate best-effort: one failure should not abort the batch.
            self.log.warning(
                "Exception trying to disable Distribution: %s error: %s",
                distribution['ARN'], e)
@StreamingDistribution.action_registry.register('disable')
class StreamingDistributionDisableAction(BaseAction):
    """Action to disable a Streaming Distribution

    Fetches each streaming distribution's config, sets ``Enabled`` to false
    and writes it back. Failures are logged per distribution and do not
    stop the remaining distributions from being processed.

    :example:

    .. code-block:: yaml

            policies:
              - name: streaming-distribution-disable
                resource: streaming-distribution
                filters:
                  - type: value
                    key: S3Origin.OriginAccessIdentity
                    value: ''
                actions:
                  - type: disable
    """
    schema = type_schema('disable')
    # IAM actions for CloudFront live under the "cloudfront:" service prefix.
    permissions = ("cloudfront:GetStreamingDistributionConfig",
                   "cloudfront:UpdateStreamingDistribution",)

    def process(self, distributions):
        client = local_session(
            self.manager.session_factory).client(self.manager.get_model().service)
        for d in distributions:
            self.process_distribution(client, d)

    def process_distribution(self, client, distribution):
        """Disable one streaming distribution; log (don't raise) on failure."""
        try:
            dist_id = distribution[self.manager.get_model().id]
            res = client.get_streaming_distribution_config(Id=dist_id)
            config = res['StreamingDistributionConfig']
            config['Enabled'] = False
            # The update must carry the ETag of the config it is based on.
            client.update_streaming_distribution(
                Id=dist_id,
                IfMatch=res['ETag'],
                StreamingDistributionConfig=config,
            )
        except Exception as e:
            # Deliberate best-effort: one failure should not abort the batch.
            self.log.warning(
                "Exception trying to disable Distribution: %s error: %s",
                distribution['ARN'], e)
@Distribution.action_registry.register('set-protocols')
class DistributionSSLAction(BaseAction):
    """Action to set mandatory https-only on a Distribution

    Each configured key overrides the corresponding setting; keys that are
    not given leave the existing value untouched:

    - ``ViewerProtocolPolicy``: applied to every cache behavior and to the
      default cache behavior.
    - ``OriginProtocolPolicy`` / ``OriginSslProtocols``: applied to every
      origin that has a ``CustomOriginConfig``.

    Failures are logged per distribution and do not stop the remaining
    distributions from being processed.

    :example:

    .. code-block:: yaml

            policies:
              - name: distribution-set-ssl
                resource: distribution
                filters:
                  - type: value
                    key: CacheBehaviors.Items[].ViewerProtocolPolicy
                    value: allow-all
                    op: contains
                actions:
                  - type: set-protocols
                    ViewerProtocolPolicy: https-only
    """
    schema = {
        'type': 'object',
        'additionalProperties': False,
        'properties': {
            'type': {'enum': ['set-protocols']},
            'OriginProtocolPolicy': {
                'enum': ['http-only', 'match-viewer', 'https-only']
            },
            'OriginSslProtocols': {
                'type': 'array',
                'items': {'enum': ['SSLv3', 'TLSv1', 'TLSv1.1', 'TLSv1.2']}
            },
            'ViewerProtocolPolicy': {
                'enum': ['allow-all', 'https-only', 'redirect-to-https']
            }
        }
    }

    # IAM actions for CloudFront live under the "cloudfront:" service prefix.
    permissions = ("cloudfront:GetDistributionConfig",
                   "cloudfront:UpdateDistribution",)

    def process(self, distributions):
        client = local_session(self.manager.session_factory).client(
            self.manager.get_model().service)
        for d in distributions:
            self.process_distribution(client, d)

    def _apply_protocols(self, dc):
        """Apply the configured protocol overrides to DistributionConfig ``dc`` in place."""
        for item in dc['CacheBehaviors'].get('Items', []):
            item['ViewerProtocolPolicy'] = self.data.get(
                'ViewerProtocolPolicy', item['ViewerProtocolPolicy'])

        default_behavior = dc['DefaultCacheBehavior']
        default_behavior['ViewerProtocolPolicy'] = self.data.get(
            'ViewerProtocolPolicy', default_behavior['ViewerProtocolPolicy'])

        for item in dc['Origins'].get('Items', []):
            custom = item.get('CustomOriginConfig', False)
            if not custom:
                continue
            custom['OriginProtocolPolicy'] = self.data.get(
                'OriginProtocolPolicy', custom['OriginProtocolPolicy'])
            ssl_protocols = custom['OriginSslProtocols']
            # Copy so origins never share (and alias) the policy's own list.
            ssl_protocols['Items'] = list(self.data.get(
                'OriginSslProtocols', ssl_protocols['Items']))
            # CloudFront requires Quantity to agree with len(Items).
            ssl_protocols['Quantity'] = len(ssl_protocols['Items'])

    def process_distribution(self, client, distribution):
        """Update one distribution's protocols; log (don't raise) on failure."""
        try:
            dist_id = distribution[self.manager.get_model().id]
            res = client.get_distribution_config(Id=dist_id)
            etag = res['ETag']
            dc = res['DistributionConfig']
            self._apply_protocols(dc)
            # The update must carry the ETag of the config it is based on.
            client.update_distribution(
                Id=dist_id,
                IfMatch=etag,
                DistributionConfig=dc,
            )
        except Exception as e:
            # Deliberate best-effort: one failure should not abort the batch.
            self.log.warning(
                "Exception trying to force ssl on Distribution: %s error: %s",
                distribution['ARN'], e)