Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable WireClient tests #1800

Merged
merged 5 commits into from
Mar 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions azurelinuxagent/common/protocol/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,11 +520,6 @@ def __init__(self, endpoint):
logger.info("Wire server endpoint:{0}", endpoint)
self._endpoint = endpoint
self._goal_state = None
self._hosting_env = None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to remove these on #1777

self._shared_conf = None
self._remote_access = None
self._certs = None
self._ext_conf = None
self._host_plugin = None
self.status_blob = StatusBlob(self)
self.goal_state_flusher = StateFlusher(conf.get_lib_dir())
Expand Down
5 changes: 3 additions & 2 deletions tests/common/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
WALAEventOperation, parse_xml_event, parse_json_event, AGENT_EVENT_FILE_EXTENSION, EVENTS_DIRECTORY
from azurelinuxagent.common.future import ustr
from azurelinuxagent.common.protocol.goal_state import GoalState
from tests.protocol import mockwiredata, mock_wire_protocol
from tests.protocol import mockwiredata
from tests.protocol.mocks import mock_wire_protocol
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved mock_wire_protocol to a different file (mocks.py)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also changed the usage from

with mock_wire_protocol.create(...)...

to

with mock_wire_protocol(...)...

As I'm adding more functionality to the mocks the latter feels better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup - Looks cleaner. Thanks

from azurelinuxagent.common.version import CURRENT_AGENT, CURRENT_VERSION, AGENT_EXECUTION_MODE
from azurelinuxagent.common.osutil import get_osutil
from tests.tools import AgentTestCase, data_dir, load_data, Mock, patch, skip_if_predicate_true
Expand Down Expand Up @@ -98,7 +99,7 @@ def create_event_and_return_container_id():

self.fail("Could not find Contained ID on event")

with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol:
with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
contained_id = create_event_and_return_container_id()
# The expect value comes from DATA_FILE
self.assertEquals(contained_id, 'c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2', "Incorrect container ID")
Expand Down
1 change: 0 additions & 1 deletion tests/daemon/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from azurelinuxagent.daemon import *
from azurelinuxagent.daemon.main import OPENSSL_FIPS_ENVIRONMENT
from azurelinuxagent.pa.provision.default import ProvisionHandler
from tests.protocol import mock_wire_protocol, mockwiredata
from tests.tools import AgentTestCase, Mock, patch


Expand Down
28 changes: 10 additions & 18 deletions tests/ga/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,20 @@

from azurelinuxagent.common import event, logger
from azurelinuxagent.common.cgroup import CGroup, CpuCgroup, MemoryCgroup
from azurelinuxagent.common.cgroupconfigurator import CGroupConfigurator
from azurelinuxagent.common.cgroupstelemetry import CGroupsTelemetry, MetricValue
from azurelinuxagent.common.datacontract import get_properties
from azurelinuxagent.common.event import EventLogger, WALAEventOperation, EVENTS_DIRECTORY
from azurelinuxagent.common.event import WALAEventOperation, EVENTS_DIRECTORY
from azurelinuxagent.common.exception import HttpError
from azurelinuxagent.common.future import ustr
from azurelinuxagent.common.logger import Logger
from azurelinuxagent.common.osutil import get_osutil
from azurelinuxagent.common.osutil.default import BASE_CGROUPS, DefaultOSUtil
from azurelinuxagent.common.protocol.wire import ExtHandler, ExtHandlerProperties
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.common.telemetryevent import TelemetryEvent, TelemetryEventParam
from azurelinuxagent.common.utils import fileutil, restutil
from azurelinuxagent.common.version import AGENT_VERSION, CURRENT_VERSION, CURRENT_AGENT, DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME
from azurelinuxagent.ga.exthandlers import ExtHandlerInstance
from azurelinuxagent.ga.monitor import generate_extension_metrics_telemetry_dictionary, get_monitor_handler, \
MonitorHandler
from tests.common.test_cgroupstelemetry import make_new_cgroup
from azurelinuxagent.ga.monitor import generate_extension_metrics_telemetry_dictionary, get_monitor_handler, MonitorHandler
from tests.protocol.mockwiredata import DATA_FILE
from tests.protocol import mock_wire_protocol
from tests.tools import Mock, MagicMock, patch, AgentTestCase, data_dir, are_cgroups_enabled, i_am_root, \
skip_if_predicate_false, is_trusty_in_travis, skip_if_predicate_true, clear_singleton_instances, PropertyMock
from tests.protocol.mocks import mock_wire_protocol
from tests.tools import Mock, MagicMock, patch, AgentTestCase, clear_singleton_instances, PropertyMock
from tests.utils.event_logger_tools import EventLoggerTools


Expand Down Expand Up @@ -339,7 +331,7 @@ def _get_event_data(duration, is_success, message, name, op, version, eventId=1)
def test_collect_and_send_events(self, mock_lib_dir, patch_send_event, *_):
mock_lib_dir.return_value = self.lib_dir

with mock_wire_protocol.create(DATA_FILE) as protocol:
with mock_wire_protocol(DATA_FILE) as protocol:
monitor_handler = TestEventMonitoring._create_monitor_handler(protocol)

self._create_extension_event(message="Message-Test")
Expand Down Expand Up @@ -406,7 +398,7 @@ def test_collect_and_send_events(self, mock_lib_dir, patch_send_event, *_):
def test_collect_and_send_events_with_small_events(self, mock_lib_dir, patch_send_event, *_):
mock_lib_dir.return_value = self.lib_dir

with mock_wire_protocol.create(DATA_FILE) as protocol:
with mock_wire_protocol(DATA_FILE) as protocol:
monitor_handler = TestEventMonitoring._create_monitor_handler(protocol)

sizes = [15, 15, 15, 15] # get the powers of 2 - 2**16 is the limit
Expand All @@ -425,7 +417,7 @@ def test_collect_and_send_events_with_small_events(self, mock_lib_dir, patch_sen
def test_collect_and_send_events_with_large_events(self, mock_lib_dir, patch_send_event, *_):
mock_lib_dir.return_value = self.lib_dir

with mock_wire_protocol.create(DATA_FILE) as protocol:
with mock_wire_protocol(DATA_FILE) as protocol:
monitor_handler = TestEventMonitoring._create_monitor_handler(protocol)

sizes = [17, 17, 17] # get the powers of 2
Expand All @@ -446,7 +438,7 @@ def test_collect_and_send_with_http_post_returning_503(self, mock_lib_dir, *_):
mock_lib_dir.return_value = self.lib_dir
fileutil.mkdir(self.event_dir)

with mock_wire_protocol.create(DATA_FILE) as protocol:
with mock_wire_protocol(DATA_FILE) as protocol:
monitor_handler = TestEventMonitoring._create_monitor_handler(protocol)

sizes = [1, 2, 3] # get the powers of 2, and multiple by 1024.
Expand All @@ -472,7 +464,7 @@ def test_collect_and_send_with_send_event_generating_exception(self, mock_lib_di
mock_lib_dir.return_value = self.lib_dir
fileutil.mkdir(self.event_dir)

with mock_wire_protocol.create(DATA_FILE) as protocol:
with mock_wire_protocol(DATA_FILE) as protocol:
monitor_handler = TestEventMonitoring._create_monitor_handler(protocol)

sizes = [1, 2, 3] # get the powers of 2, and multiple by 1024.
Expand All @@ -496,7 +488,7 @@ def test_collect_and_send_with_call_wireserver_returns_http_error(self, mock_lib
mock_lib_dir.return_value = self.lib_dir
fileutil.mkdir(self.event_dir)

with mock_wire_protocol.create(DATA_FILE) as protocol:
with mock_wire_protocol(DATA_FILE) as protocol:
monitor_handler = TestEventMonitoring._create_monitor_handler(protocol)

sizes = [1, 2, 3] # get the powers of 2, and multiple by 1024.
Expand Down
7 changes: 4 additions & 3 deletions tests/ga/test_remoteaccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from azurelinuxagent.common.protocol.goal_state import GoalState, RemoteAccess
from tests.tools import AgentTestCase, load_data, patch, Mock
from tests.protocol import mockwiredata, mock_wire_protocol
from tests.protocol import mockwiredata
from tests.protocol.mocks import mock_wire_protocol


class TestRemoteAccess(AgentTestCase):
Expand All @@ -33,7 +34,7 @@ def test_parse_remote_access(self):
self.assertEquals("2019-01-01", remote_access.user_list.users[0].expiration, "Expiration does not match.")

def test_goal_state_with_no_remote_access(self):
with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol:
with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
self.assertIsNone(protocol.client.get_remote_access())

def test_parse_two_remote_access_accounts(self):
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_parse_zero_remote_access_accounts(self):
self.assertEquals(0, len(remote_access.user_list.users), "User count does not match.")

def test_update_remote_access_conf_remote_access(self):
with mock_wire_protocol.create(mockwiredata.DATA_FILE_REMOTE_ACCESS) as protocol:
with mock_wire_protocol(mockwiredata.DATA_FILE_REMOTE_ACCESS) as protocol:
self.assertIsNotNone(protocol.client.get_remote_access())
self.assertEquals(1, len(protocol.client.get_remote_access().user_list.users))
self.assertEquals('testAccount', protocol.client.get_remote_access().user_list.users[0].name)
Expand Down
84 changes: 0 additions & 84 deletions tests/protocol/mock_wire_protocol.py

This file was deleted.

142 changes: 142 additions & 0 deletions tests/protocol/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2020 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.6+ and Openssl 1.0+
#
import contextlib
import re
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.common.utils import restutil
from tests.tools import patch
from tests.protocol.mockwiredata import WireProtocolData


@contextlib.contextmanager
def mock_wire_protocol(mock_wire_data_file):
Copy link
Member Author

@narrieta narrieta Mar 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved here from the previous mock_wire_protocol.py.

Also, I changed it to mock 1 level deeper (http_request instead of http_get) to address all the scenarios in the tests.

"""
Creates a mock WireProtocol object that will return the data specified by 'mock_wire_data_file' (which must
follow the structure of the data files defined in tests/protocol/mockwiredata.py).

NOTE: This function creates mocks for azurelinuxagent.common.utils.restutil.http_request and
azurelinuxagent.common.protocol.wire.CryptUtil. These mocks can be stopped using the
methods stop_mock_http_get() and stop_mock_crypt_util().

The return value is an instance of WireProtocol augmented with these properties/methods:

* mock_wire_data - the WireProtocolData constructed from the mock_wire_data_file parameter.
* stop_mock_http_request() - stops the mock for restutil.http_request
* stop_mock_crypt_util() - stops the mock for CrypUtil
* stop() - stops both mocks
"""
def stop_mock_http_request():
if stop_mock_http_request.mock is not None:
stop_mock_http_request.mock.stop()
stop_mock_http_request.mock = None
stop_mock_http_request.mock = None

def stop_mock_crypt_util():
if stop_mock_crypt_util.mock is not None:
stop_mock_crypt_util.mock.stop()
stop_mock_crypt_util.mock = None
stop_mock_crypt_util.mock = None

def stop():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those stop functions are mainly for debugging purposes; they come handy when there are conflicts with the mocks used by tests and the ones used here.

stop_mock_crypt_util()
stop_mock_http_request()

protocol = WireProtocol(restutil.KNOWN_WIRESERVER_IP)
protocol.mock_wire_data = WireProtocolData(mock_wire_data_file)
protocol.stop_mock_http_request = stop_mock_http_request
protocol.stop_mock_crypt_util = stop_mock_crypt_util
protocol.stop = stop

try:
# To minimize the impact of mocking restutil.http_request we only use the mock data for requests
# to the wireserver or requests starting with "mock-goal-state"
mock_data_re = re.compile(r'https?://(mock-goal-state|{0}).*'.format(restutil.KNOWN_WIRESERVER_IP.replace(r'.', r'\.')), re.IGNORECASE)

original_http_request = restutil.http_request

def http_request(method, url, data, **kwargs):
if method == 'GET' and mock_data_re.match(url) is not None:
return protocol.mock_wire_data.mock_http_get(url, **kwargs)
elif method == 'POST':
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added support for POST

return protocol.mock_wire_data.mock_http_post(url, data, **kwargs)
return original_http_request(method, url, data, **kwargs)

patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request)
patched.start()
stop_mock_http_request.mock = patched

patched = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util)
patched.start()
stop_mock_crypt_util.mock = patched

protocol.detect()

yield protocol

finally:
protocol.stop()


@contextlib.contextmanager
def mock_http_request(http_get_handler=None, http_post_handler=None, http_put_handler=None):
Copy link
Member Author

@narrieta narrieta Mar 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New mock useful for tests that need to interact with the host/storage at the HTTP request level.

I'll probably merge it with mock_wire_protocol since most test scenarios likely will use mock_http_request nested within mock_wire_protocol. Will see first what kind of usage this mock gets in other tests.

"""
Creates a Mock of restutil.http_request that executes the handler given for the corresponding HTTP method.

The return value of the handler function is interpreted similarly to the "return_value" argument of patch(): if it
is an exception the exception is raised or, if it is any object other than None, the value is returned by the mock.

If the handler function returns None the call is passed to the original restutil.http_request.

The patch maintains a list of "tracked" urls. When the handler function returns a value than is not None the url
for the request is automatically added to the tracked list. The handler function can add other items to this list
using the track_url() method on the mock.

The returned Mock is augmented with these 2 methods:

* track_url(url) - adds the given item to the list of tracked urls.
* get_tracked_urls() - returns the list of tracked urls.
"""
tracked_urls = []
original_http_request = restutil.http_request

def http_request(method, url, *args, **kwargs):
handler = None
if method == 'GET':
handler = http_get_handler
elif method == 'POST':
handler = http_post_handler
elif method == 'PUT':
handler = http_put_handler

if handler is not None:
return_value = handler(url, *args, **kwargs)
if return_value is not None:
tracked_urls.append(url)
if isinstance(return_value, Exception):
raise return_value
return return_value

return original_http_request(method, url, *args, **kwargs)

patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request)
patched.track_url = lambda url: tracked_urls.append(url)
patched.get_tracked_urls = lambda: tracked_urls
patched.start()
try:
yield patched
finally:
patched.stop()
Loading