diff --git a/homeassistant/components/rainmachine/__init__.py b/homeassistant/components/rainmachine/__init__.py index 2ff5ddcd4aa2ea..2594816811ed69 100644 --- a/homeassistant/components/rainmachine/__init__.py +++ b/homeassistant/components/rainmachine/__init__.py @@ -1,22 +1,20 @@ """Support for RainMachine devices.""" import logging from datetime import timedelta -from functools import wraps import voluptuous as vol -from homeassistant.auth.permissions.const import POLICY_CONTROL from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import ( ATTR_ATTRIBUTION, CONF_BINARY_SENSORS, CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_SSL, CONF_MONITORED_CONDITIONS, CONF_SWITCHES) -from homeassistant.exceptions import ( - ConfigEntryNotReady, Unauthorized, UnknownUser) +from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.service import verify_domain_control from .config_flow import configured_instances from .const import ( @@ -131,44 +129,6 @@ }, extra=vol.ALLOW_EXTRA) -def _check_valid_user(hass): - """Ensure the user of a service call has proper permissions.""" - def decorator(service): - """Decorate.""" - @wraps(service) - async def check_permissions(call): - """Check user permission and raise before call if unauthorized.""" - if not call.context.user_id: - return - - user = await hass.auth.async_get_user(call.context.user_id) - if user is None: - raise UnknownUser( - context=call.context, - permission=POLICY_CONTROL - ) - - # RainMachine services don't interact with specific entities. - # Therefore, we examine _all_ RainMachine entities and if the user - # has permission to control _any_ of them, the user has permission - # to call the service: - en_reg = await hass.helpers.entity_registry.async_get_registry() - rainmachine_entities = [ - entity.entity_id for entity in en_reg.entities.values() - if entity.platform == DOMAIN - ] - for entity_id in rainmachine_entities: - if user.permissions.check_entity(entity_id, POLICY_CONTROL): - return await service(call) - - raise Unauthorized( - context=call.context, - permission=POLICY_CONTROL, - ) - return check_permissions - return decorator - - async def async_setup(hass, config): """Set up the RainMachine component.""" hass.data[DOMAIN] = {} @@ -198,6 +158,8 @@ async def async_setup_entry(hass, config_entry): from regenmaschine import login from regenmaschine.errors import RainMachineError + _verify_domain_control = verify_domain_control(hass, DOMAIN) + websession = aiohttp_client.async_get_clientsession(hass) try: @@ -238,69 +200,69 @@ async def refresh(event_time): refresh, timedelta(seconds=config_entry.data[CONF_SCAN_INTERVAL])) - @_check_valid_user(hass) + @_verify_domain_control async def disable_program(call): """Disable a program.""" await rainmachine.client.programs.disable( call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def disable_zone(call): """Disable a zone.""" await rainmachine.client.zones.disable(call.data[CONF_ZONE_ID]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def enable_program(call): """Enable a program.""" await rainmachine.client.programs.enable(call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def enable_zone(call): """Enable a zone.""" await rainmachine.client.zones.enable(call.data[CONF_ZONE_ID]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def pause_watering(call): """Pause watering for a set number of seconds.""" await rainmachine.client.watering.pause_all(call.data[CONF_SECONDS]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def start_program(call): """Start a particular program.""" await rainmachine.client.programs.start(call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def start_zone(call): """Start a particular zone for a certain amount of time.""" await rainmachine.client.zones.start( call.data[CONF_ZONE_ID], call.data[CONF_ZONE_RUN_TIME]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def stop_all(call): """Stop all watering.""" await rainmachine.client.watering.stop_all() async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def stop_program(call): """Stop a program.""" await rainmachine.client.programs.stop(call.data[CONF_PROGRAM_ID]) async_dispatcher_send(hass, PROGRAM_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def stop_zone(call): """Stop a zone.""" await rainmachine.client.zones.stop(call.data[CONF_ZONE_ID]) async_dispatcher_send(hass, ZONE_UPDATE_TOPIC) - @_check_valid_user(hass) + @_verify_domain_control async def unpause_watering(call): """Unpause watering.""" await rainmachine.client.watering.unpause_all() diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index ea62d12c66c02e..3a9a7f0937caf7 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -7,17 +7,20 @@ import voluptuous as vol -from homeassistant.auth.permissions.const import POLICY_CONTROL +from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.const import ( ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID) import homeassistant.core as ha -from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser +from homeassistant.exceptions import ( + HomeAssistantError, TemplateError, Unauthorized, UnknownUser) from homeassistant.helpers import template, typing from homeassistant.loader import get_component, bind_hass from homeassistant.util.yaml import load_yaml import homeassistant.helpers.config_validation as cv from homeassistant.util.async_ import run_coroutine_threadsafe +from .typing import HomeAssistantType + CONF_SERVICE = 'service' CONF_SERVICE_TEMPLATE = 'service_template' CONF_SERVICE_ENTITY_ID = 'entity_id' @@ -360,3 +363,47 @@ async def admin_handler(call): hass.services.async_register( domain, service, admin_handler, schema ) + + +@bind_hass +@ha.callback +def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable: + """Ensure permission to access any entity under domain in service call.""" + def decorator(service_handler: Callable) -> Callable: + """Decorate.""" + if not asyncio.iscoroutinefunction(service_handler): + raise HomeAssistantError( + 'Can only decorate async functions.') + + async def check_permissions(call): + """Check user permission and raise before call if unauthorized.""" + if not call.context.user_id: + return await service_handler(call) + + user = await hass.auth.async_get_user(call.context.user_id) + if user is None: + raise UnknownUser( + context=call.context, + permission=POLICY_CONTROL, + user_id=call.context.user_id) + + reg = await hass.helpers.entity_registry.async_get_registry() + entities = [ + entity.entity_id for entity in reg.entities.values() + if entity.platform == domain + ] + + for entity_id in entities: + if user.permissions.check_entity(entity_id, POLICY_CONTROL): + return await service_handler(call) + + raise Unauthorized( + context=call.context, + permission=POLICY_CONTROL, + user_id=call.context.user_id, + perm_category=CAT_ENTITIES + ) + + return check_permissions + + return decorator diff --git a/tests/components/rainmachine/conftest.py b/tests/components/rainmachine/conftest.py deleted file mode 100644 index fdc81151995601..00000000000000 --- a/tests/components/rainmachine/conftest.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Configuration for Rainmachine tests.""" -import pytest - -from homeassistant.components.rainmachine.const import DOMAIN -from homeassistant.const import ( - CONF_IP_ADDRESS, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SSL) - -from tests.common import MockConfigEntry - - -@pytest.fixture(name="config_entry") -def config_entry_fixture(): - """Create a mock RainMachine config entry.""" - return MockConfigEntry( - domain=DOMAIN, - title='192.168.1.101', - data={ - CONF_IP_ADDRESS: '192.168.1.101', - CONF_PASSWORD: '12345', - CONF_PORT: 8080, - CONF_SSL: True, - CONF_SCAN_INTERVAL: 60, - }) diff --git a/tests/components/rainmachine/test_service_permissions.py b/tests/components/rainmachine/test_service_permissions.py deleted file mode 100644 index caa84337517c00..00000000000000 --- a/tests/components/rainmachine/test_service_permissions.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Define tests for permissions on RainMachine service calls.""" -import asynctest -import pytest - -from homeassistant.components.rainmachine.const import DOMAIN -from homeassistant.core import Context -from homeassistant.exceptions import Unauthorized, UnknownUser -from homeassistant.setup import async_setup_component - -from tests.common import mock_coro - - -async def setup_platform(hass, config_entry): - """Set up the media player platform for testing.""" - with asynctest.mock.patch('regenmaschine.login') as mock_login: - mock_client = mock_login.return_value - mock_client.restrictions.current.return_value = mock_coro() - mock_client.restrictions.universal.return_value = mock_coro() - config_entry.add_to_hass(hass) - assert await async_setup_component(hass, DOMAIN) - await hass.async_block_till_done() - - -async def test_services_authorization( - hass, config_entry, hass_read_only_user): - """Test that a RainMachine service is halted on incorrect permissions.""" - await setup_platform(hass, config_entry) - - with pytest.raises(UnknownUser): - await hass.services.async_call( - 'rainmachine', - 'unpause_watering', {}, - blocking=True, - context=Context(user_id='fake_user_id')) - - with pytest.raises(Unauthorized): - await hass.services.async_call( - 'rainmachine', - 'unpause_watering', {}, - blocking=True, - context=Context(user_id=hass_read_only_user.id)) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 231ffddff3095c..e6f4b15457e455 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -38,12 +38,14 @@ def mock_entities(): available=True, should_poll=False, supported_features=1, + platform='test_domain', ) living_room = Mock( entity_id='light.living_room', available=True, should_poll=False, supported_features=0, + platform='test_domain', ) entities = OrderedDict() entities[kitchen.entity_id] = kitchen @@ -461,3 +463,116 @@ async def mock_service(call): )) assert len(calls) == 1 assert calls[0].context.user_id == hass_admin_user.id + + +async def test_domain_control_not_async(hass, mock_entities): + """Test domain verification in a service call with an unknown user.""" + calls = [] + + def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with pytest.raises(exceptions.HomeAssistantError): + hass.helpers.service.verify_domain_control( + 'test_domain')(mock_service_log) + + +async def test_domain_control_unknown(hass, mock_entities): + """Test domain verification in a service call with an unknown user.""" + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with patch('homeassistant.helpers.entity_registry.async_get_registry', + return_value=mock_coro(Mock(entities=mock_entities))): + protected_mock_service = hass.helpers.service.verify_domain_control( + 'test_domain')(mock_service_log) + + hass.services.async_register( + 'test_domain', 'test_service', protected_mock_service, schema=None) + + with pytest.raises(exceptions.UnknownUser): + await hass.services.async_call( + 'test_domain', + 'test_service', {}, + blocking=True, + context=ha.Context(user_id='fake_user_id')) + assert len(calls) == 0 + + +async def test_domain_control_unauthorized( + hass, hass_read_only_user, mock_entities): + """Test domain verification in a service call with an unauthorized user.""" + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with patch('homeassistant.helpers.entity_registry.async_get_registry', + return_value=mock_coro(Mock(entities=mock_entities))): + protected_mock_service = hass.helpers.service.verify_domain_control( + 'test_domain')(mock_service_log) + + hass.services.async_register( + 'test_domain', 'test_service', protected_mock_service, schema=None) + + with pytest.raises(exceptions.Unauthorized): + await hass.services.async_call( + 'test_domain', + 'test_service', {}, + blocking=True, + context=ha.Context(user_id=hass_read_only_user.id)) + + +async def test_domain_control_admin(hass, hass_admin_user, mock_entities): + """Test domain verification in a service call with an admin user.""" + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with patch('homeassistant.helpers.entity_registry.async_get_registry', + return_value=mock_coro(Mock(entities=mock_entities))): + protected_mock_service = hass.helpers.service.verify_domain_control( + 'test_domain')(mock_service_log) + + hass.services.async_register( + 'test_domain', 'test_service', protected_mock_service, schema=None) + + await hass.services.async_call( + 'test_domain', + 'test_service', {}, + blocking=True, + context=ha.Context(user_id=hass_admin_user.id)) + + assert len(calls) == 1 + + +async def test_domain_control_no_user(hass, mock_entities): + """Test domain verification in a service call with no user.""" + calls = [] + + async def mock_service_log(call): + """Define a protected service.""" + calls.append(call) + + with patch('homeassistant.helpers.entity_registry.async_get_registry', + return_value=mock_coro(Mock(entities=mock_entities))): + protected_mock_service = hass.helpers.service.verify_domain_control( + 'test_domain')(mock_service_log) + + hass.services.async_register( + 'test_domain', 'test_service', protected_mock_service, schema=None) + + await hass.services.async_call( + 'test_domain', + 'test_service', {}, + blocking=True, + context=ha.Context(user_id=None)) + + assert len(calls) == 1