# Copyright 2018 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import datetime
import io
import logging
import os
import shutil
import tempfile
import unittest
import mock
import six
import yaml
from c7n import policy
from c7n.schema import generate, validate as schema_validate
from c7n.ctx import ExecutionContext
from c7n.utils import reset_session_cache
from c7n.config import Bag, Config
C7N_VALIDATE = bool(os.environ.get("C7N_VALIDATE", ""))
skip_if_not_validating = unittest.skipIf(
not C7N_VALIDATE, reason="We are not validating schemas."
)
try:
import pytest
functional = pytest.mark.functional
except ImportError:
functional = lambda func: func # noqa E731
[docs]class TestUtils(unittest.TestCase):
custodian_schema = None
[docs] def tearDown(self):
self.cleanUp()
[docs] def cleanUp(self):
# Clear out thread local session cache
reset_session_cache()
[docs] def write_policy_file(self, policy, format="yaml"):
""" Write a policy file to disk in the specified format.
Input a dictionary and a format. Valid formats are `yaml` and `json`
Returns the file path.
"""
fh = tempfile.NamedTemporaryFile(mode="w+b", suffix="." + format, delete=False)
if format == "json":
fh.write(json.dumps(policy).encode("utf8"))
else:
fh.write(yaml.dump(policy, encoding="utf8", Dumper=yaml.SafeDumper))
fh.flush()
self.addCleanup(os.unlink, fh.name)
self.addCleanup(fh.close)
return fh.name
[docs] def get_temp_dir(self):
""" Return a temporary directory that will get cleaned up. """
temp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, temp_dir)
return temp_dir
[docs] def get_context(self, config=None, session_factory=None, policy=None):
if config is None:
self.context_output_dir = self.get_temp_dir()
config = Config.empty(output_dir=self.context_output_dir)
ctx = ExecutionContext(
session_factory, policy or Bag({
"name": "test-policy", "provider_name": "aws"}), config)
return ctx
[docs] def load_policy(
self,
data,
config=None,
session_factory=None,
validate=C7N_VALIDATE,
output_dir=None,
cache=False,
):
if validate:
if not self.custodian_schema:
self.custodian_schema = generate()
errors = schema_validate({"policies": [data]}, self.custodian_schema)
if errors:
raise errors[0]
config = config or {}
if not output_dir:
temp_dir = self.get_temp_dir()
config["output_dir"] = temp_dir
if cache:
config["cache"] = os.path.join(temp_dir, "c7n.cache")
config["cache_period"] = 300
conf = Config.empty(**config)
p = policy.Policy(data, conf, session_factory)
p.validate()
return p
[docs] def load_policy_set(self, data, config=None):
filename = self.write_policy_file(data)
if config:
e = Config.empty(**config)
else:
e = Config.empty()
return policy.load(e, filename)
[docs] def patch(self, obj, attr, new):
old = getattr(obj, attr, None)
setattr(obj, attr, new)
self.addCleanup(setattr, obj, attr, old)
[docs] def change_cwd(self, work_dir=None):
if work_dir is None:
work_dir = self.get_temp_dir()
cur_dir = os.path.abspath(os.getcwd())
def restore():
os.chdir(cur_dir)
self.addCleanup(restore)
os.chdir(work_dir)
return work_dir
[docs] def change_environment(self, **kwargs):
"""Change the environment to the given set of variables.
To clear an environment variable set it to None.
Existing environment restored after test.
"""
# preserve key elements needed for testing
for env in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"]:
if env not in kwargs:
kwargs[env] = os.environ.get(env, "")
original_environ = dict(os.environ)
@self.addCleanup
def cleanup_env():
os.environ.clear()
os.environ.update(original_environ)
os.environ.clear()
for key, value in list(kwargs.items()):
if value is None:
del (kwargs[key])
os.environ.update(kwargs)
[docs] def capture_logging(
self, name=None, level=logging.INFO, formatter=None, log_file=None
):
if log_file is None:
log_file = TextTestIO()
log_handler = logging.StreamHandler(log_file)
if formatter:
log_handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.addHandler(log_handler)
old_logger_level = logger.level
logger.setLevel(level)
@self.addCleanup
def reset_logging():
logger.removeHandler(log_handler)
logger.setLevel(old_logger_level)
return log_file
[docs]class TextTestIO(io.StringIO):
[docs] def write(self, b):
# print handles both str/bytes and unicode/str, but io.{String,Bytes}IO
# requires us to choose. We don't have control over all of the places
# we want to print from (think: traceback.print_exc) so we can't
# standardize the arg type up at the call sites. Hack it here.
if not isinstance(b, six.text_type):
b = b.decode("utf8")
return super(TextTestIO, self).write(b)
# Per http://blog.xelnor.net/python-mocking-datetime/
# naive implementation has issues with pypy
real_datetime_class = datetime.datetime
[docs]def mock_datetime_now(tgt, dt):
class DatetimeSubclassMeta(type):
@classmethod
def __instancecheck__(mcs, obj):
return isinstance(obj, real_datetime_class)
class BaseMockedDatetime(real_datetime_class):
target = tgt
@classmethod
def now(cls, tz=None):
return cls.target.replace(tzinfo=tz)
@classmethod
def utcnow(cls):
return cls.target
# Python2 & Python3 compatible metaclass
MockedDatetime = DatetimeSubclassMeta(
b"datetime" if str is bytes else "datetime", # hack Python2/3 port
(BaseMockedDatetime,),
{},
)
return mock.patch.object(dt, "datetime", MockedDatetime)