diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3de79a8261..982f6a6f3d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ in development * Allow user to pass a boolean value for the ``cacert`` st2client constructor argument. This way it now mimics the behavior of the ``verify`` argument of the ``requests.request`` method. (improvement) +* Add datastore access to Python actions. (new-feature) #2396 [Kale Blankenship] 1.3.0 - January 22, 2016 ------------------------ diff --git a/st2actions/st2actions/runners/python_action_wrapper.py b/st2actions/st2actions/runners/python_action_wrapper.py index 80e443b731..26b7f5b0af 100644 --- a/st2actions/st2actions/runners/python_action_wrapper.py +++ b/st2actions/st2actions/runners/python_action_wrapper.py @@ -16,6 +16,7 @@ import sys import json import argparse +import logging as stdlib_logging from st2common import log as logging from st2actions import config @@ -23,6 +24,8 @@ from st2common.util import loader as action_loader from st2common.util.config_parser import ContentPackConfigParser from st2common.constants.action import ACTION_OUTPUT_RESULT_DELIMITER +from st2common.service_setup import db_setup +from st2common.services.datastore import DatastoreService __all__ = [ 'PythonActionWrapper' @@ -46,10 +49,14 @@ def __init__(self, pack, file_path, parameters=None, parent_args=None): :param parent_args: Command line arguments passed to the parent process. :type parse_args: ``list`` """ + db_setup() + self._pack = pack self._file_path = file_path self._parameters = parameters or {} self._parent_args = parent_args or [] + self._class_name = None + self._logger = logging.getLogger('PythonActionWrapper') try: config.parse_args(args=self._parent_args) @@ -85,10 +92,37 @@ def _get_action_instance(self): LOG.info('Using config "%s" for action "%s"' % (config.file_path, self._file_path)) - return action_cls(config=config.config) + action_instance = action_cls(config=config.config) else: LOG.info('No config found for action "%s"' % (self._file_path)) - return action_cls(config={}) + action_instance = action_cls(config={}) + + # Setup action_instance proeprties + action_instance.logger = self._set_up_logger(action_cls.__name__) + action_instance.datastore = DatastoreService(logger=action_instance.logger, + pack_name=self._pack, + class_name=action_cls.__name__, + api_username="action_service") + + return action_instance + + def _set_up_logger(self, action_name): + """ + Set up a logger which logs all the messages with level DEBUG + and above to stderr. + """ + logger_name = 'actions.python.%s' % (action_name) + logger = logging.getLogger(logger_name) + + console = stdlib_logging.StreamHandler() + console.setLevel(stdlib_logging.DEBUG) + + formatter = stdlib_logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + console.setFormatter(formatter) + logger.addHandler(console) + logger.setLevel(stdlib_logging.DEBUG) + + return logger if __name__ == '__main__': diff --git a/st2actions/st2actions/runners/pythonrunner.py b/st2actions/st2actions/runners/pythonrunner.py index 66da0f8d35..4e8ab40e0a 100644 --- a/st2actions/st2actions/runners/pythonrunner.py +++ b/st2actions/st2actions/runners/pythonrunner.py @@ -18,14 +18,12 @@ import abc import json import uuid -import logging as stdlib_logging import six from eventlet.green import subprocess from st2actions.runners import ActionRunner from st2common.util.green.shell import run_command -from st2common import log as logging from st2common.constants.action import ACTION_OUTPUT_RESULT_DELIMITER from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED from st2common.constants.action import LIVEACTION_STATUS_FAILED @@ -44,8 +42,6 @@ 'Action' ] -LOG = logging.getLogger(__name__) - # constants to lookup in runner_parameters. RUNNER_ENV = 'env' RUNNER_TIMEOUT = 'timeout' @@ -73,30 +69,14 @@ def __init__(self, config=None): :type config: ``dict`` """ self.config = config or {} - self.logger = self._set_up_logger() + # logger and datastore are assigned in PythonActionWrapper._get_action_instance + self.logger = None + self.datastore = None @abc.abstractmethod def run(self, **kwargs): pass - def _set_up_logger(self): - """ - Set up a logger which logs all the messages with level DEBUG - and above to stderr. - """ - logger_name = 'actions.python.%s' % (self.__class__.__name__) - logger = logging.getLogger(logger_name) - - console = stdlib_logging.StreamHandler() - console.setLevel(stdlib_logging.DEBUG) - - formatter = stdlib_logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') - console.setFormatter(formatter) - logger.addHandler(console) - logger.setLevel(stdlib_logging.DEBUG) - - return logger - class PythonRunner(ActionRunner): diff --git a/st2common/st2common/services/datastore.py b/st2common/st2common/services/datastore.py new file mode 100644 index 0000000000..d649a499c3 --- /dev/null +++ b/st2common/st2common/services/datastore.py @@ -0,0 +1,219 @@ +# Licensed to the StackStorm, Inc ('StackStorm') under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from st2client.client import Client +from st2client.models import KeyValuePair +from st2common.services.access import create_token +from st2common.util.api import get_full_public_api_url + + +class DatastoreService(object): + """ + Class provides public methods for accessing datastore items. + """ + + DATASTORE_NAME_SEPARATOR = ':' + + def __init__(self, logger, pack_name, class_name, api_username): + self._api_username = api_username + self._pack_name = pack_name + self._class_name = class_name + self._logger = logger + + self._client = None + + ################################## + # Methods for datastore management + ################################## + + def list_values(self, local=True, prefix=None): + """ + Retrieve all the datastores items. + + :param local: List values from a namespace local to this pack/class. Defaults to True. + :type: local: ``bool`` + + :param prefix: Optional key name prefix / startswith filter. + :type prefix: ``str`` + + :rtype: ``list`` of :class:`KeyValuePair` + """ + client = self._get_api_client() + self._logger.audit('Retrieving all the value from the datastore') + + key_prefix = self._get_full_key_prefix(local=local, prefix=prefix) + kvps = client.keys.get_all(prefix=key_prefix) + return kvps + + def get_value(self, name, local=True): + """ + Retrieve a value from the datastore for the provided key. + + By default, value is retrieved from the namespace local to the pack/class. If you want to + retrieve a global value from a datastore, pass local=False to this method. + + :param name: Key name. + :type name: ``str`` + + :param local: Retrieve value from a namespace local to the pack/class. Defaults to True. + :type: local: ``bool`` + + :rtype: ``str`` or ``None`` + """ + name = self._get_full_key_name(name=name, local=local) + + client = self._get_api_client() + self._logger.audit('Retrieving value from the datastore (name=%s)', name) + + try: + kvp = client.keys.get_by_id(id=name) + except Exception: + return None + + if kvp: + return kvp.value + + return None + + def set_value(self, name, value, ttl=None, local=True): + """ + Set a value for the provided key. + + By default, value is set in a namespace local to the pack/class. If you want to + set a global value, pass local=False to this method. + + :param name: Key name. + :type name: ``str`` + + :param value: Key value. + :type value: ``str`` + + :param ttl: Optional TTL (in seconds). + :type ttl: ``int`` + + :param local: Set value in a namespace local to the pack/class. Defaults to True. + :type: local: ``bool`` + + :return: ``True`` on success, ``False`` otherwise. + :rtype: ``bool`` + """ + name = self._get_full_key_name(name=name, local=local) + + value = str(value) + client = self._get_api_client() + + self._logger.audit('Setting value in the datastore (name=%s)', name) + + instance = KeyValuePair() + instance.id = name + instance.name = name + instance.value = value + + if ttl: + instance.ttl = ttl + + client.keys.update(instance=instance) + return True + + def delete_value(self, name, local=True): + """ + Delete the provided key. + + By default, value is deleted from a namespace local to the pack/class. If you want to + delete a global value, pass local=False to this method. + + :param name: Name of the key to delete. + :type name: ``str`` + + :param local: Delete a value in a namespace local to the pack/class. Defaults to True. + :type: local: ``bool`` + + :return: ``True`` on success, ``False`` otherwise. + :rtype: ``bool`` + """ + name = self._get_full_key_name(name=name, local=local) + + client = self._get_api_client() + + instance = KeyValuePair() + instance.id = name + instance.name = name + + self._logger.audit('Deleting value from the datastore (name=%s)', name) + + try: + client.keys.delete(instance=instance) + except Exception: + return False + + return True + + def _get_api_client(self): + """ + Retrieve API client instance. + """ + if not self._client: + ttl = (24 * 60 * 60) + temporary_token = create_token(username=self._api_username, ttl=ttl) + api_url = get_full_public_api_url() + self._client = Client(api_url=api_url, token=temporary_token.token) + + return self._client + + def _get_full_key_name(self, name, local): + """ + Retrieve a full key name. + + :rtype: ``str`` + """ + if local: + name = self._get_key_name_with_prefix(name=name) + + return name + + def _get_full_key_prefix(self, local, prefix=None): + if local: + key_prefix = self._get_local_key_name_prefix() + + if prefix: + key_prefix += prefix + else: + key_prefix = prefix + + return key_prefix + + def _get_local_key_name_prefix(self): + """ + Retrieve key prefix which is local to this pack/class. + """ + key_prefix = self._get_datastore_key_prefix() + self.DATASTORE_NAME_SEPARATOR + return key_prefix + + def _get_key_name_with_prefix(self, name): + """ + Retrieve a full key name which is local to the current pack/class. + + :param name: Base datastore key name. + :type name: ``str`` + + :rtype: ``str`` + """ + prefix = self._get_datastore_key_prefix() + full_name = prefix + self.DATASTORE_NAME_SEPARATOR + name + return full_name + + def _get_datastore_key_prefix(self): + prefix = '%s.%s' % (self._pack_name, self._class_name) + return prefix diff --git a/st2common/tests/unit/test_datastore.py b/st2common/tests/unit/test_datastore.py new file mode 100644 index 0000000000..9d383c598b --- /dev/null +++ b/st2common/tests/unit/test_datastore.py @@ -0,0 +1,110 @@ +# Licensed to the StackStorm, Inc ('StackStorm') under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import unittest2 + +import mock + +from st2common.services.datastore import DatastoreService +from st2client.models.keyvalue import KeyValuePair + +CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) + + +class DatastoreServiceTestCase(unittest2.TestCase): + def setUp(self): + super(DatastoreServiceTestCase, self).setUp() + + self._datastore_service = DatastoreService(logger=mock.Mock(), + pack_name='core', + class_name='TestSensor', + api_username='sensor_service') + self._datastore_service._get_api_client = mock.Mock() + + def test_datastore_operations_list_values(self): + # Verify prefix filtering + mock_api_client = mock.Mock() + mock_api_client.keys.get_all.return_value = [] + self._set_mock_api_client(mock_api_client) + + self._datastore_service.list_values(local=True, prefix=None) + mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:') + self._datastore_service.list_values(local=True, prefix='ponies') + mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:ponies') + + self._datastore_service.list_values(local=False, prefix=None) + mock_api_client.keys.get_all.assert_called_with(prefix=None) + self._datastore_service.list_values(local=False, prefix='ponies') + mock_api_client.keys.get_all.assert_called_with(prefix='ponies') + + # No values in the datastore + mock_api_client = mock.Mock() + mock_api_client.keys.get_all.return_value = [] + self._set_mock_api_client(mock_api_client) + + values = self._datastore_service.list_values(local=True) + self.assertEqual(values, []) + values = self._datastore_service.list_values(local=False) + self.assertEqual(values, []) + + # Values in the datastore + kvp1 = KeyValuePair() + kvp1.name = 'test1' + kvp1.value = 'bar' + kvp2 = KeyValuePair() + kvp2.name = 'test2' + kvp2.value = 'bar' + mock_return_value = [kvp1, kvp2] + mock_api_client.keys.get_all.return_value = mock_return_value + self._set_mock_api_client(mock_api_client) + + values = self._datastore_service.list_values(local=True) + self.assertEqual(len(values), 2) + self.assertEqual(values, mock_return_value) + + def test_datastore_operations_get_value(self): + mock_api_client = mock.Mock() + kvp1 = KeyValuePair() + kvp1.name = 'test1' + kvp1.value = 'bar' + mock_api_client.keys.get_by_id.return_value = kvp1 + self._set_mock_api_client(mock_api_client) + + value = self._datastore_service.get_value(name='test1', local=False) + self.assertEqual(value, kvp1.value) + + def test_datastore_operations_set_value(self): + mock_api_client = mock.Mock() + mock_api_client.keys.update.return_value = True + self._set_mock_api_client(mock_api_client) + + value = self._datastore_service.set_value(name='test1', value='foo', local=False) + self.assertTrue(value) + + def test_datastore_operations_delete_value(self): + mock_api_client = mock.Mock() + mock_api_client.keys.delete.return_value = True + self._set_mock_api_client(mock_api_client) + + value = self._datastore_service.delete_value(name='test', local=False) + self.assertTrue(value) + + def _set_mock_api_client(self, mock_api_client): + mock_method = mock.Mock() + mock_method.return_value = mock_api_client + self._datastore_service._get_api_client = mock_method diff --git a/st2reactor/st2reactor/container/sensor_wrapper.py b/st2reactor/st2reactor/container/sensor_wrapper.py index b7020d4afc..44dabbc7ca 100644 --- a/st2reactor/st2reactor/container/sensor_wrapper.py +++ b/st2reactor/st2reactor/container/sensor_wrapper.py @@ -21,7 +21,6 @@ import eventlet from oslo_config import cfg -from st2client.client import Client from st2common import log as logging from st2common.logging.misc import set_log_level_for_all_loggers @@ -33,9 +32,7 @@ from st2common.services.triggerwatcher import TriggerWatcher from st2reactor.sensor.base import Sensor, PollingSensor from st2reactor.sensor import config -from st2common.constants.system import API_URL_ENV_VARIABLE_NAME -from st2common.constants.system import AUTH_TOKEN_ENV_VARIABLE_NAME -from st2client.models.keyvalue import KeyValuePair +from st2common.services.datastore import DatastoreService __all__ = [ 'SensorWrapper' @@ -55,12 +52,14 @@ class SensorService(object): methods which can be called by the sensor. """ - DATASTORE_NAME_SEPARATOR = ':' - def __init__(self, sensor_wrapper): self._sensor_wrapper = sensor_wrapper self._logger = self._sensor_wrapper._logger self._dispatcher = TriggerDispatcher(self._logger) + self._datastore_service = DatastoreService(logger=self._logger, + pack_name=self._sensor_wrapper._pack, + class_name=self._sensor_wrapper._class_name, + api_username='sensor_service') self._client = None @@ -111,190 +110,16 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None): ################################## def list_values(self, local=True, prefix=None): - """ - Retrieve all the datastores items. - - :param local: List values from a namespace local to this sensor. Defaults to True. - :type: local: ``bool`` - - :param prefix: Optional key name prefix / startswith filter. - :type prefix: ``str`` - - :rtype: ``list`` of :class:`KeyValuePair` - """ - client = self._get_api_client() - self._logger.audit('Retrieving all the value from the datastore') - - key_prefix = self._get_full_key_prefix(local=local, prefix=prefix) - kvps = client.keys.get_all(prefix=key_prefix) - return kvps + return self._datastore_service.list_values(local, prefix) def get_value(self, name, local=True): - """ - Retrieve a value from the datastore for the provided key. - - By default, value is retrieved from the namespace local to the sensor. If you want to - retrieve a global value from a datastore, pass local=False to this method. - - :param name: Key name. - :type name: ``str`` - - :param local: Retrieve value from a namespace local to the sensor. Defaults to True. - :type: local: ``bool`` - - :rtype: ``str`` or ``None`` - """ - name = self._get_full_key_name(name=name, local=local) - - client = self._get_api_client() - self._logger.audit('Retrieving value from the datastore (name=%s)', name) - - try: - kvp = client.keys.get_by_id(id=name) - except Exception: - return None - - if kvp: - return kvp.value - - return None + return self._datastore_service.get_value(name, local) def set_value(self, name, value, ttl=None, local=True): - """ - Set a value for the provided key. - - By default, value is set in a namespace local to the sensor. If you want to - set a global value, pass local=False to this method. - - :param name: Key name. - :type name: ``str`` - - :param value: Key value. - :type value: ``str`` - - :param ttl: Optional TTL (in seconds). - :type ttl: ``int`` - - :param local: Set value in a namespace local to the sensor. Defaults to True. - :type: local: ``bool`` - - :return: ``True`` on success, ``False`` otherwise. - :rtype: ``bool`` - """ - name = self._get_full_key_name(name=name, local=local) - - value = str(value) - client = self._get_api_client() - - self._logger.audit('Setting value in the datastore (name=%s)', name) - - instance = KeyValuePair() - instance.id = name - instance.name = name - instance.value = value - - if ttl: - instance.ttl = ttl - - client.keys.update(instance=instance) - return True + return self._datastore_service.set_value(name, value, ttl, local) def delete_value(self, name, local=True): - """ - Delete the provided key. - - By default, value is deleted from a namespace local to the sensor. If you want to - delete a global value, pass local=False to this method. - - :param name: Name of the key to delete. - :type name: ``str`` - - :param local: Delete a value in a namespace local to the sensor. Defaults to True. - :type: local: ``bool`` - - :return: ``True`` on success, ``False`` otherwise. - :rtype: ``bool`` - """ - name = self._get_full_key_name(name=name, local=local) - - client = self._get_api_client() - - instance = KeyValuePair() - instance.id = name - instance.name = name - - self._logger.audit('Deleting value from the datastore (name=%s)', name) - - try: - client.keys.delete(instance=instance) - except Exception: - return False - - return True - - def _get_api_client(self): - """ - Retrieve API client instance. - """ - # TODO: API client is really unfriendly and needs to be re-designed and - # improved - api_url = os.environ.get(API_URL_ENV_VARIABLE_NAME, None) - auth_token = os.environ.get(AUTH_TOKEN_ENV_VARIABLE_NAME, None) - - if not api_url or not auth_token: - raise ValueError('%s and %s environment variable must be set' % - (API_URL_ENV_VARIABLE_NAME, AUTH_TOKEN_ENV_VARIABLE_NAME)) - - if not self._client: - self._client = Client(api_url=api_url) - - return self._client - - def _get_full_key_name(self, name, local): - """ - Retrieve a full key name. - - :rtype: ``str`` - """ - if local: - name = self._get_key_name_with_sensor_prefix(name=name) - - return name - - def _get_full_key_prefix(self, local, prefix=None): - if local: - key_prefix = self._get_sensor_local_key_name_prefix() - - if prefix: - key_prefix += prefix - else: - key_prefix = prefix - - return key_prefix - - def _get_sensor_local_key_name_prefix(self): - """ - Retrieve key prefix which is local to this sensor. - """ - key_prefix = self._get_datastore_key_prefix() + self.DATASTORE_NAME_SEPARATOR - return key_prefix - - def _get_key_name_with_sensor_prefix(self, name): - """ - Retrieve a full key name which is local to the current sensor. - - :param name: Base datastore key name. - :type name: ``str`` - - :rtype: ``str`` - """ - prefix = self._get_datastore_key_prefix() - full_name = prefix + self.DATASTORE_NAME_SEPARATOR + name - return full_name - - def _get_datastore_key_prefix(self): - prefix = '%s.%s' % (self._sensor_wrapper._pack, self._sensor_wrapper._class_name) - return prefix + return self._datastore_service.delete_value(name, local) class SensorWrapper(object): diff --git a/st2reactor/tests/unit/test_sensor_wrapper.py b/st2reactor/tests/unit/test_sensor_wrapper.py index 490f1facad..1b5fb288d9 100644 --- a/st2reactor/tests/unit/test_sensor_wrapper.py +++ b/st2reactor/tests/unit/test_sensor_wrapper.py @@ -6,9 +6,7 @@ import st2tests.config as tests_config from st2tests.base import TESTS_CONFIG_PATH from st2reactor.container.sensor_wrapper import SensorWrapper -from st2reactor.container.sensor_wrapper import SensorService from st2reactor.sensor.base import Sensor, PollingSensor -from st2client.models.keyvalue import KeyValuePair CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) @@ -93,96 +91,3 @@ def test_sensor_creation_active(self): self.assertIsNotNone(wrapper._sensor_instance) self.assertIsInstance(wrapper._sensor_instance, PollingSensor) self.assertEquals(wrapper._sensor_instance._poll_interval, poll_interval) - - -class SensorServiceTestCase(unittest2.TestCase): - @classmethod - def setUpClass(cls): - super(SensorServiceTestCase, cls).setUpClass() - tests_config.parse_args() - - def setUp(self): - super(SensorServiceTestCase, self).setUp() - - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) - self._sensor_service = SensorService(sensor_wrapper=wrapper) - self._sensor_service._get_api_client = mock.Mock() - - def test_datastore_operations_list_values(self): - # Verify prefix filtering - mock_api_client = mock.Mock() - mock_api_client.keys.get_all.return_value = [] - self._set_mock_api_client(mock_api_client) - - self._sensor_service.list_values(local=True, prefix=None) - mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:') - self._sensor_service.list_values(local=True, prefix='ponies') - mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:ponies') - - self._sensor_service.list_values(local=False, prefix=None) - mock_api_client.keys.get_all.assert_called_with(prefix=None) - self._sensor_service.list_values(local=False, prefix='ponies') - mock_api_client.keys.get_all.assert_called_with(prefix='ponies') - - # No values in the datastore - mock_api_client = mock.Mock() - mock_api_client.keys.get_all.return_value = [] - self._set_mock_api_client(mock_api_client) - - values = self._sensor_service.list_values(local=True) - self.assertEqual(values, []) - values = self._sensor_service.list_values(local=False) - self.assertEqual(values, []) - - # Values in the datastore - kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' - kvp2 = KeyValuePair() - kvp2.name = 'test2' - kvp2.value = 'bar' - mock_return_value = [kvp1, kvp2] - mock_api_client.keys.get_all.return_value = mock_return_value - self._set_mock_api_client(mock_api_client) - - values = self._sensor_service.list_values(local=True) - self.assertEqual(len(values), 2) - self.assertEqual(values, mock_return_value) - - def test_datastore_operations_get_value(self): - mock_api_client = mock.Mock() - kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' - mock_api_client.keys.get_by_id.return_value = kvp1 - self._set_mock_api_client(mock_api_client) - - value = self._sensor_service.get_value(name='test1', local=False) - self.assertEqual(value, kvp1.value) - - def test_datastore_operations_set_value(self): - mock_api_client = mock.Mock() - mock_api_client.keys.update.return_value = True - self._set_mock_api_client(mock_api_client) - - value = self._sensor_service.set_value(name='test1', value='foo', local=False) - self.assertTrue(value) - - def test_datastore_operations_delete_value(self): - mock_api_client = mock.Mock() - mock_api_client.keys.delete.return_value = True - self._set_mock_api_client(mock_api_client) - - value = self._sensor_service.delete_value(name='test', local=False) - self.assertTrue(value) - - def _set_mock_api_client(self, mock_api_client): - mock_method = mock.Mock() - mock_method.return_value = mock_api_client - self._sensor_service._get_api_client = mock_method diff --git a/st2tests/st2tests/mocks/datastore.py b/st2tests/st2tests/mocks/datastore.py new file mode 100644 index 0000000000..3867a0cc17 --- /dev/null +++ b/st2tests/st2tests/mocks/datastore.py @@ -0,0 +1,95 @@ +# Licensed to the StackStorm, Inc ('StackStorm') under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mock classes for use in pack testing. +""" + +from st2common.services.datastore import DatastoreService +from st2client.models.keyvalue import KeyValuePair + +__all__ = [ + 'MockDatastoreService' +] + + +class MockDatastoreService(DatastoreService): + """ + Mock DatastoreService for use in testing. + """ + def __init__(self, logger, pack_name, class_name, api_username): + self._pack_name = pack_name + self._class_name = class_name + + # Holds mock KeyValuePair objects + # Key is a KeyValuePair name and value is the KeyValuePair object + self._datastore_items = {} + + def list_values(self, local=True, prefix=None): + """ + Return a list of all values stored in a dictionary which is local to this class. + """ + key_prefix = self._get_full_key_prefix(local=local, prefix=prefix) + + if not key_prefix: + return self._datastore_items.values() + + result = [] + for name, kvp in self._datastore_items.items(): + if name.startswith(key_prefix): + result.append(kvp) + + return result + + def get_value(self, name, local=True): + """ + Return a particular value stored in a dictionary which is local to this class. + """ + name = self._get_full_key_name(name=name, local=local) + + if name not in self._datastore_items: + return None + + kvp = self._datastore_items[name] + return kvp.value + + def set_value(self, name, value, ttl=None, local=True): + """ + Store a value in a dictionary which is local to this class. + """ + if ttl: + raise ValueError('MockDatastoreService.set_value doesn\'t support "ttl" argument') + + name = self._get_full_key_name(name=name, local=local) + + instance = KeyValuePair() + instance.id = name + instance.name = name + instance.value = value + + self._datastore_items[name] = instance + return True + + def delete_value(self, name, local=True): + """ + Delete a value from a dictionary which is local to this class. + """ + name = self._get_full_key_name(name=name, local=local) + + if name not in self._datastore_items: + return False + + del self._datastore_items[name] + return True diff --git a/st2tests/st2tests/mocks/sensor.py b/st2tests/st2tests/mocks/sensor.py index 464a8dbfd1..9b89435fdc 100644 --- a/st2tests/st2tests/mocks/sensor.py +++ b/st2tests/st2tests/mocks/sensor.py @@ -22,7 +22,7 @@ from mock import Mock from st2reactor.container.sensor_wrapper import SensorService -from st2client.models.keyvalue import KeyValuePair +from st2tests.mocks.datastore import MockDatastoreService __all__ = [ 'MockSensorWrapper', @@ -48,13 +48,14 @@ def __init__(self, sensor_wrapper): # We use a Mock class so use can assert logger was called with particular arguments self._logger = Mock(spec=RootLogger) - # Holds mock KeyValuePair objects - # Key is a KeyValuePair name and value is the KeyValuePair object - self._datastore_items = {} - # Holds a list of triggers which were dispatched self.dispatched_triggers = [] + self._datastore_service = MockDatastoreService(logger=self._logger, + pack_name=self._sensor_wrapper._pack, + class_name=self._sensor_wrapper._class_name, + api_username='sensor_service') + def get_logger(self, name): """ Return mock logger instance. @@ -72,60 +73,3 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None): 'trace_context': trace_context } self.dispatched_triggers.append(item) - - def list_values(self, local=True, prefix=None): - """ - Return a list of all values stored in a dictionary which is local to this class. - """ - key_prefix = self._get_full_key_prefix(local=local, prefix=prefix) - - if not key_prefix: - return self._datastore_items.values() - - result = [] - for name, kvp in self._datastore_items.items(): - if name.startswith(key_prefix): - result.append(kvp) - - return result - - def get_value(self, name, local=True): - """ - Return a particular value stored in a dictionary which is local to this class. - """ - name = self._get_full_key_name(name=name, local=local) - - if name not in self._datastore_items: - return None - - kvp = self._datastore_items[name] - return kvp.value - - def set_value(self, name, value, ttl=None, local=True): - """ - Store a value in a dictionary which is local to this class. - """ - if ttl: - raise ValueError('MockSensorService.set_value doesn\'t support "ttl" argument') - - name = self._get_full_key_name(name=name, local=local) - - instance = KeyValuePair() - instance.id = name - instance.name = name - instance.value = value - - self._datastore_items[name] = instance - return True - - def delete_value(self, name, local=True): - """ - Delete a value from a dictionary which is local to this class. - """ - name = self._get_full_key_name(name=name, local=local) - - if name not in self._datastore_items: - return False - - del self._datastore_items[name] - return True