diff --git a/caravel/utils.py b/caravel/utils.py index 5f2a76009204e..15c0d75cd7ffa 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -18,6 +18,14 @@ from sqlalchemy.types import TypeDecorator, TEXT +class CaravelException(Exception): + pass + + +class CaravelSecurityException(CaravelException): + pass + + def flasher(msg, severity=None): """Flask's flash if available, logging call if not""" try: diff --git a/caravel/views.py b/caravel/views.py index dd351d066fe02..d5d4140b3a686 100644 --- a/caravel/views.py +++ b/caravel/views.py @@ -35,6 +35,35 @@ log_this = models.Log.log_this +def check_ownership(obj, raise_if_false=True): + """Meant to be used in `pre_update` hooks on models to enforce ownership + + Admin have all access, and other users need to be referenced on either + the created_by field that comes with the ``AuditMixin``, or in a field + named ``owners`` which is expected to be a one-to-many with the User + model. It is meant to be used in the ModelView's pre_update hook in + which raising will abort the update. + """ + roles = (r.name for r in get_user_roles()) + if 'Admin' in roles: + return True + session = db.create_scoped_session() + orig_obj = session.query(obj.__class__).filter_by(id=obj.id).first() + owner_names = (user.username for user in orig_obj.owners) + if ( + hasattr(orig_obj, 'created_by') and + orig_obj.created_by and + orig_obj.created_by.username == g.user.username): + return True + if hasattr(orig_obj, 'owners') and g.user.username in owner_names: + return True + if raise_if_false: + raise utils.CaravelSecurityException( + "You don't have the rights to alter [{}]".format(obj)) + else: + return False + + def get_user_roles(): if g.user.is_anonymous(): return [appbuilder.sm.find_role('Public')] @@ -56,7 +85,6 @@ def apply(self, query, func): # noqa if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): return query qry = query.filter(self.model.perm.in_(self.get_perms())) - print(qry) return qry @@ -65,16 +93,23 @@ def apply(self, query, func): # noqa if any([r.name in ('Admin', 'Alpha') for r in get_user_roles()]): return query Slice = models.Slice # noqa + Dash = models.Dashboard # noqa slice_ids_qry = ( db.session .query(Slice.id) .filter(Slice.perm.in_(self.get_perms())) ) - return query.filter( - self.model.slices.any( - models.Slice.id.in_(slice_ids_qry) + print([r for r in slice_ids_qry.all()]) + query = query.filter( + Dash.id.in_( + db.session.query(Dash.id) + .distinct() + .join(Dash.slices) + .filter(Slice.id.in_(slice_ids_qry)) ) ) + print(query) + return query def validate_json(form, field): # noqa @@ -407,6 +442,9 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa 'viz_type': _("Visualization Type"), } + def pre_update(self, obj): + check_ownership(obj) + appbuilder.add_view( SliceModelView, __("Slices"), @@ -470,6 +508,7 @@ def pre_add(self, obj): obj.slug = re.sub(r'\W+', '', obj.slug) def pre_update(self, obj): + check_ownership() self.pre_add(obj) @@ -762,11 +801,15 @@ def save_slice(self, slc): flash(msg, "info") def overwrite_slice(self, slc): - session = db.session() - msg = "Slice [{}] has been overwritten".format(slc.slice_name) - session.merge(slc) - session.commit() - flash(msg, "info") + can_update = check_ownership(slc, raise_if_false=False) + if not can_update: + flash("You cannot overwrite [{}]".format(slc)) + else: + session = db.session() + session.merge(slc) + session.commit() + msg = "Slice [{}] has been overwritten".format(slc.slice_name) + flash(msg, "info") @has_access @expose("/checkbox////", methods=['GET']) @@ -810,6 +853,7 @@ def save_dash(self, dashboard_id): session = db.session() Dash = models.Dashboard # noqa dash = session.query(Dash).filter_by(id=dashboard_id).first() + check_ownership(dash, raise_if_false=True) dash.slices = [o for o in dash.slices if o.id in slice_ids] dash.position_json = json.dumps(data['positions'], indent=4) md = dash.metadata_dejson @@ -961,7 +1005,7 @@ def runsql(self): if ( not self.appbuilder.sm.has_access( 'all_datasource_access', 'all_datasource_access')): - raise Exception(_( + raise utils.CaravelSecurityException(_( "This view requires the `all_datasource_access` permission")) content = "" if mydb: diff --git a/tests/core_tests.py b/tests/core_tests.py index a6528af9a2fa8..b450c3e81a020 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -6,6 +6,7 @@ from datetime import datetime import doctest +import json import imp import os import unittest @@ -36,6 +37,7 @@ def __init__(self, *args, **kwargs): self.client = app.test_client() utils.init(caravel) + admin = appbuilder.sm.find_user('admin') if not admin: appbuilder.sm.add_user( @@ -49,30 +51,42 @@ def __init__(self, *args, **kwargs): 'gamma', 'gamma', 'user', 'gamma@fab.org', appbuilder.sm.find_role('Gamma'), password='general') + + alpha = appbuilder.sm.find_user('alpha') + if not alpha: + appbuilder.sm.add_user( + 'alpha', 'alpha', 'user', 'alpha@fab.org', + appbuilder.sm.find_role('Alpha'), + password='general') + utils.init(caravel) - def login_admin(self): + def login(self, username='admin', password='general'): resp = self.client.post( '/login/', - data=dict(username='admin', password='general'), + data=dict(username=username, password=password), follow_redirects=True) assert 'Welcome' in resp.data.decode('utf-8') - def login_gamma(self): - resp = self.client.post( - '/login/', - data=dict(username='gamma', password='general'), - follow_redirects=True) - assert 'Welcome' in resp.data.decode('utf-8') + def logout(self): + resp = self.client.get('/logout/', follow_redirects=True) - def setup_public_access_for_dashboard(self, dashboard_name): + def setup_public_access_for_dashboard(self, table_name): public_role = appbuilder.sm.find_role('Public') perms = db.session.query(ab_models.PermissionView).all() for perm in perms: - if (perm.permission.name == 'datasource_access' and - perm.view_menu and dashboard_name in perm.view_menu.name): + if ( perm.permission.name == 'datasource_access' and + perm.view_menu and table_name in perm.view_menu.name): appbuilder.sm.add_permission_role(public_role, perm) + def revoke_public_access(self, table_name): + public_role = appbuilder.sm.find_role('Public') + perms = db.session.query(ab_models.PermissionView).all() + for perm in perms: + if ( perm.permission.name == 'datasource_access' and + perm.view_menu and table_name in perm.view_menu.name): + appbuilder.sm.del_permission_role(public_role, perm) + class CoreTests(CaravelTestCase): @@ -97,7 +111,7 @@ def load_examples(self): cli.load_examples(load_test_data=True) def test_save_slice(self): - self.login_admin() + self.login(username='admin') slice_id = ( db.session.query(models.Slice.id) @@ -120,7 +134,7 @@ def test_save_slice(self): def test_slices(self): # Testing by running all the examples - self.login_admin() + self.login(username='admin') Slc = models.Slice urls = [] for slc in db.session.query(Slc).all(): @@ -134,7 +148,7 @@ def test_slices(self): self.client.get(url) def test_dashboard(self): - self.login_admin() + self.login(username='admin') urls = {} for dash in db.session.query(models.Dashboard).all(): urls[dash.dashboard_title] = dash.url @@ -153,23 +167,35 @@ def test_misc(self): assert self.client.get('/ping').data.decode('utf-8') == "OK" def test_shortner(self): - self.login_admin() + self.login(username='admin') data = "//caravel/explore/table/1/?viz_type=sankey&groupby=source&groupby=target&metric=sum__value&row_limit=5000&where=&having=&flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id=78&slice_name=Energy+Sankey&collapsed_fieldsets=&action=&datasource_name=energy_usage&datasource_id=1&datasource_type=table&previous_viz_type=sankey" resp = self.client.post('/r/shortner/', data=data) assert '/r/' in resp.data.decode('utf-8') - def test_save_dash(self): - self.login_admin() + def test_save_dash(self, username='admin'): + self.login(username=username) dash = db.session.query(models.Dashboard).filter_by(slug="births").first() - data = """{"positions":[{"slice_id":"131","col":8,"row":8,"size_x":2,"size_y":4},{"slice_id":"132","col":10,"row":8,"size_x":2,"size_y":4},{"slice_id":"133","col":1,"row":1,"size_x":2,"size_y":2},{"slice_id":"134","col":3,"row":1,"size_x":2,"size_y":2},{"slice_id":"135","col":5,"row":4,"size_x":3,"size_y":3},{"slice_id":"136","col":1,"row":7,"size_x":7,"size_y":4},{"slice_id":"137","col":9,"row":1,"size_x":3,"size_y":3},{"slice_id":"138","col":5,"row":1,"size_x":4,"size_y":3},{"slice_id":"139","col":1,"row":3,"size_x":4,"size_y":4},{"slice_id":"140","col":8,"row":4,"size_x":4,"size_y":4}],"css":"None","expanded_slices":{}}""" + positions = [] + for i, slc in enumerate(dash.slices): + d = { + 'col': 0, + 'row': i * 4, + 'size_x': 4, + 'size_y': 4, + 'slice_id': '{}'.format(slc.id)} + positions.append(d) + data = { + 'css': '', + 'expanded_slices': {}, + 'positions': positions, + } url = '/caravel/save_dash/{}/'.format(dash.id) - resp = self.client.post(url, data=dict(data=data)) + resp = self.client.post(url, data=dict(data=json.dumps(data))) assert "SUCCESS" in resp.data.decode('utf-8') def test_gamma(self): - self.login_gamma() + self.login(username='gamma') resp = self.client.get('/slicemodelview/list/') - print(resp.data.decode('utf-8')) assert "List Slice" in resp.data.decode('utf-8') resp = self.client.get('/dashboardmodelview/list/') @@ -177,50 +203,67 @@ def test_gamma(self): def test_public_user_dashboard_access(self): # Try access before adding appropriate permissions. + self.revoke_public_access('birth_names') + self.logout() + resp = self.client.get('/slicemodelview/list/') data = resp.data.decode('utf-8') - assert 'birth_names' not in data - resp = self.client.get('/dashboardmodelview/list/') - data = resp.data.decode('utf-8') - assert '' not in data + assert 'birth_names' not in data - resp = self.client.get('/caravel/explore/table/3/', follow_redirects=True) + resp = self.client.get('/dashboardmodelview/list/') data = resp.data.decode('utf-8') - assert "You don't seem to have access to this datasource" in data + assert '/caravel/dashboard/births/' not in data self.setup_public_access_for_dashboard('birth_names') # Try access after adding appropriate permissions. resp = self.client.get('/slicemodelview/list/') data = resp.data.decode('utf-8') - assert 'birth_names' in data + assert 'birth_names' in data resp = self.client.get('/dashboardmodelview/list/') data = resp.data.decode('utf-8') - assert '' in data + assert "/caravel/dashboard/births/" in data resp = self.client.get('/caravel/dashboard/births/') data = resp.data.decode('utf-8') - assert '[dashboard] Births' in data - - resp = self.client.get('/caravel/explore/table/3/') - data = resp.data.decode('utf-8') - assert '[explore] birth_names' in data + assert 'Births' in data # Confirm that public doesn't have access to other datasets. - resp = self.client.get('/slicemodelview/list/') - data = resp.data.decode('utf-8') - assert 'wb_health_population' not in data - resp = self.client.get('/dashboardmodelview/list/') data = resp.data.decode('utf-8') - assert '' not in data + assert "/caravel/dashboard/world_health/" not in data - resp = self.client.get('/caravel/explore/table/2/', follow_redirects=True) - data = resp.data.decode('utf-8') - assert "You don't seem to have access to this datasource" in data + def test_only_owners_can_save(self): + dash = ( + db.session + .query(models.Dashboard) + .filter_by(slug="births") + .first() + ) + dash.owners = [] + db.session.merge(dash) + db.session.commit() + self.test_save_dash('admin') + + self.logout() + self.assertRaises( + utils.CaravelSecurityException, self.test_save_dash, 'alpha') + + alpha = appbuilder.sm.find_user('alpha') + + dash = ( + db.session + .query(models.Dashboard) + .filter_by(slug="births") + .first() + ) + dash.owners = [alpha] + db.session.merge(dash) + db.session.commit() + self.test_save_dash('alpha') SEGMENT_METADATA = [{ "id": "some_id", @@ -278,7 +321,7 @@ def __init__(self, *args, **kwargs): @patch('caravel.models.PyDruid') def test_client(self, PyDruid): - self.login_admin() + self.login(username='admin') instance = PyDruid.return_value instance.time_boundary.return_value = [ {'result': {'maxTime': '2016-01-01'}}] @@ -321,8 +364,6 @@ def test_client(self, PyDruid): instance.query_dict = {} instance.query_builder.last_query.query_dict = {} resp = self.client.get('/caravel/explore/druid/1/?viz_type=table&granularity=one+day&druid_time_origin=&since=7+days+ago&until=now&row_limit=5000&include_search=false&metrics=count&groupby=name&flt_col_0=dim1&flt_op_0=in&flt_eq_0=&slice_id=&slice_name=&collapsed_fieldsets=&action=&datasource_name=test_datasource&datasource_id=1&datasource_type=druid&previous_viz_type=table&json=true&force=true') - print('-'*300) - print(resp.data.decode('utf-8')) assert "Canada" in resp.data.decode('utf-8')