[druid] Updating refresh logic (apache#4655)
john-bodley authored Mar 27, 2018
1 parent 52b925f commit f9d85bd
Showing 5 changed files with 182 additions and 55 deletions.
7 changes: 3 additions & 4 deletions superset/cli.py
@@ -166,10 +166,9 @@ def load_examples(load_test_data):
)
@manager.option(
'-m', '--merge',
help=(
"Specify using 'merge' property during operation. "
'Default value is False '
),
action='store_true',
help="Specify using 'merge' property during operation.",
default=False,
)
def refresh_druid(datasource, merge):
"""Refresh druid datasources"""
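Note on the reworked --merge option: it is now a plain boolean flag. Flask-Script's @manager.option hands its keyword arguments to argparse's add_argument, so the same pattern can be sketched standalone with argparse (the parser and argument names below are illustrative, not from the commit):

import argparse

parser = argparse.ArgumentParser(description='refresh_druid sketch')
parser.add_argument(
    '-m', '--merge',
    action='store_true',   # passing -m/--merge stores True; no value is parsed
    default=False,         # omitting the flag leaves merge=False
    help="Specify using 'merge' property during operation.",
)

print(parser.parse_args([]).merge)      # False
print(parser.parse_args(['-m']).merge)  # True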
67 changes: 27 additions & 40 deletions superset/connectors/druid/models.py
Expand Up @@ -29,7 +29,7 @@
from six import string_types
import sqlalchemy as sa
from sqlalchemy import (
Boolean, Column, DateTime, ForeignKey, Integer, or_, String, Text, UniqueConstraint,
Boolean, Column, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint,
)
from sqlalchemy.orm import backref, relationship

@@ -200,33 +200,31 @@ def refresh(self, datasource_names, merge_flag, refreshAll):
col_objs_list = (
session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(or_(DruidColumn.column_name == col for col in cols))
.filter(DruidColumn.column_name.in_(cols.keys()))
)
col_objs = {col.column_name: col for col in col_objs_list}
for col in cols:
if col == '__time': # skip the time column
continue
col_obj = col_objs.get(col, None)
col_obj = col_objs.get(col)
if not col_obj:
col_obj = DruidColumn(
datasource_id=datasource.id,
column_name=col)
with session.no_autoflush:
session.add(col_obj)
datatype = cols[col]['type']
if datatype == 'STRING':
col_obj.type = cols[col]['type']
col_obj.datasource = datasource
if col_obj.type == 'STRING':
col_obj.groupby = True
col_obj.filterable = True
if datatype == 'hyperUnique' or datatype == 'thetaSketch':
if col_obj.type == 'hyperUnique' or col_obj.type == 'thetaSketch':
col_obj.count_distinct = True
# Allow sum/min/max for long or double
if datatype == 'LONG' or datatype == 'DOUBLE':
if col_obj.is_num:
col_obj.sum = True
col_obj.min = True
col_obj.max = True
col_obj.type = datatype
col_obj.datasource = datasource
datasource.generate_metrics_for(col_objs_list)
datasource.refresh_metrics()
session.commit()

@property
@@ -361,21 +359,24 @@ def get_metrics(self):
)
return metrics

def generate_metrics(self):
"""Generate metrics based on the column metadata"""
def refresh_metrics(self):
"""Refresh metrics based on the column metadata"""
metrics = self.get_metrics()
dbmetrics = (
db.session.query(DruidMetric)
.filter(DruidMetric.datasource_id == self.datasource_id)
.filter(or_(
DruidMetric.metric_name == m for m in metrics
))
.filter(DruidMetric.metric_name.in_(metrics.keys()))
)
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics.values():
metric.datasource_id = self.datasource_id
if not dbmetrics.get(metric.metric_name, None):
db.session.add(metric)
dbmetric = dbmetrics.get(metric.metric_name)
if dbmetric:
for attr in ['json', 'metric_type', 'verbose_name']:
setattr(dbmetric, attr, getattr(metric, attr))
else:
with db.session.no_autoflush:
metric.datasource_id = self.datasource_id
db.session.add(metric)

@classmethod
def import_obj(cls, i_column):
@@ -653,24 +654,9 @@ def latest_metadata(self):
if segment_metadata:
return segment_metadata[-1]['columns']

def generate_metrics(self):
self.generate_metrics_for(self.columns)

def generate_metrics_for(self, columns):
metrics = {}
for col in columns:
metrics.update(col.get_metrics())
dbmetrics = (
db.session.query(DruidMetric)
.filter(DruidMetric.datasource_id == self.id)
.filter(or_(DruidMetric.metric_name == m for m in metrics))
)
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics.values():
metric.datasource_id = self.id
if not dbmetrics.get(metric.metric_name, None):
with db.session.no_autoflush:
db.session.add(metric)
def refresh_metrics(self):
for col in self.columns:
col.refresh_metrics()

@classmethod
def sync_to_db_from_config(
@@ -703,7 +689,7 @@ def sync_to_db_from_config(
col_objs = (
session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(or_(DruidColumn.column_name == dim for dim in dimensions))
.filter(DruidColumn.column_name.in_(dimensions))
)
col_objs = {col.column_name: col for col in col_objs}
for dim in dimensions:
@@ -723,8 +709,9 @@ def sync_to_db_from_config(
metric_objs = (
session.query(DruidMetric)
.filter(DruidMetric.datasource_id == datasource.id)
.filter(or_(DruidMetric.metric_name == spec['name']
for spec in druid_config['metrics_spec']))
.filter(DruidMetric.metric_name.in_(
spec['name'] for spec in druid_config['metrics_spec']
))
)
metric_objs = {metric.metric_name: metric for metric in metric_objs}
for metric_spec in druid_config['metrics_spec']:
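Note on the query changes in models.py: each lookup that previously OR'd together per-name equality tests (built from a generator expression passed to or_()) now uses a single in_() predicate. A minimal standalone sketch of the two forms, using a toy model rather than the real DruidColumn:

import sqlalchemy as sa
from sqlalchemy import or_
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class ColumnRecord(Base):  # toy stand-in for DruidColumn
    __tablename__ = 'columns'
    id = sa.Column(sa.Integer, primary_key=True)
    column_name = sa.Column(sa.String(255))

names = ['dim1', 'dim2', 'metric1']

# Old style: OR together one equality comparison per name.
old_clause = or_(*[ColumnRecord.column_name == name for name in names])

# New style: one IN predicate over the whole collection.
new_clause = ColumnRecord.column_name.in_(names)

print(old_clause)  # columns.column_name = :column_name_1 OR columns.column_name = :column_name_2 OR ...
print(new_clause)  # columns.column_name IN (...)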
4 changes: 2 additions & 2 deletions superset/connectors/druid/views.py
@@ -91,7 +91,7 @@ def pre_update(self, col):
.format(dimension_spec['outputName'], col.column_name))

def post_update(self, col):
col.generate_metrics()
col.refresh_metrics()

def post_add(self, col):
self.post_update(col)
@@ -277,7 +277,7 @@ def pre_add(self, datasource):
datasource.full_name))

def post_add(self, datasource):
datasource.generate_metrics()
datasource.refresh_metrics()
security.merge_perm(sm, 'datasource_access', datasource.get_perm())
if datasource.schema:
security.merge_perm(sm, 'schema_access', datasource.schema_perm)
72 changes: 72 additions & 0 deletions superset/migrations/versions/f231d82b9b26_.py
@@ -0,0 +1,72 @@
"""empty message
Revision ID: f231d82b9b26
Revises: e68c4473c581
Create Date: 2018-03-20 19:47:54.991259
"""

# revision identifiers, used by Alembic.
revision = 'f231d82b9b26'
down_revision = 'e68c4473c581'

from alembic import op
import sqlalchemy as sa
from sqlalchemy.exc import OperationalError

from superset.utils import generic_find_uq_constraint_name

conv = {
'uq': 'uq_%(table_name)s_%(column_0_name)s',
}

names = {
'columns': 'column_name',
'metrics': 'metric_name',
}

bind = op.get_bind()
insp = sa.engine.reflection.Inspector.from_engine(bind)


def upgrade():

# Reduce the size of the metric_name column for constraint viability.
with op.batch_alter_table('metrics', naming_convention=conv) as batch_op:
batch_op.alter_column(
'metric_name',
existing_type=sa.String(length=512),
type_=sa.String(length=255),
existing_nullable=True,
)

# Add the missing uniqueness constraints.
for table, column in names.items():
with op.batch_alter_table(table, naming_convention=conv) as batch_op:
batch_op.create_unique_constraint(
'uq_{}_{}'.format(table, column),
[column, 'datasource_id'],
)

def downgrade():

# Restore the size of the metric_name column.
with op.batch_alter_table('metrics', naming_convention=conv) as batch_op:
batch_op.alter_column(
'metric_name',
existing_type=sa.String(length=255),
type_=sa.String(length=512),
existing_nullable=True,
)

# Remove the previous missing uniqueness constraints.
for table, column in names.items():
with op.batch_alter_table(table, naming_convention=conv) as batch_op:
batch_op.drop_constraint(
generic_find_uq_constraint_name(
table,
{column, 'datasource_id'},
insp,
) or 'uq_{}_{}'.format(table, column),
type_='unique',
)
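Note on the migration: it leans on SQLAlchemy's constraint naming conventions. The 'uq' template above yields names such as uq_columns_column_name and uq_metrics_metric_name, which is also the fallback the downgrade uses when generic_find_uq_constraint_name cannot reflect the real name. A small standalone sketch of how the convention names an otherwise anonymous unique constraint (assumed table definition, not from the migration):

import sqlalchemy as sa
from sqlalchemy.schema import CreateTable

conv = {'uq': 'uq_%(table_name)s_%(column_0_name)s'}
metadata = sa.MetaData(naming_convention=conv)

metrics = sa.Table(
    'metrics', metadata,
    sa.Column('metric_name', sa.String(255)),
    sa.Column('datasource_id', sa.Integer),
    sa.UniqueConstraint('metric_name', 'datasource_id'),  # no explicit name given
)

# The emitted DDL names the constraint via the convention, e.g.
#   CONSTRAINT uq_metrics_metric_name UNIQUE (metric_name, datasource_id)
print(CreateTable(metrics))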
87 changes: 78 additions & 9 deletions tests/druid_tests.py
@@ -13,7 +13,7 @@

from superset import db, security, sm
from superset.connectors.druid.models import (
DruidCluster, DruidDatasource,
DruidCluster, DruidColumn, DruidDatasource, DruidMetric,
)
from .base_tests import SupersetTestCase

@@ -29,22 +29,27 @@ def __reduce__(self):
'columns': {
'__time': {
'type': 'LONG', 'hasMultipleValues': False,
'size': 407240380, 'cardinality': None, 'errorMessage': None},
'size': 407240380, 'cardinality': None, 'errorMessage': None,
},
'dim1': {
'type': 'STRING', 'hasMultipleValues': False,
'size': 100000, 'cardinality': 1944, 'errorMessage': None},
'size': 100000, 'cardinality': 1944, 'errorMessage': None,
},
'dim2': {
'type': 'STRING', 'hasMultipleValues': True,
'size': 100000, 'cardinality': 1504, 'errorMessage': None},
'size': 100000, 'cardinality': 1504, 'errorMessage': None,
},
'metric1': {
'type': 'FLOAT', 'hasMultipleValues': False,
'size': 100000, 'cardinality': None, 'errorMessage': None},
'size': 100000, 'cardinality': None, 'errorMessage': None,
},
},
'aggregators': {
'metric1': {
'type': 'longSum',
'name': 'metric1',
'fieldName': 'metric1'},
'fieldName': 'metric1',
},
},
'size': 300000,
'numRows': 5000000,
@@ -87,9 +92,7 @@ def get_test_cluster_obj(self):
broker_port=7980,
metadata_last_refreshed=datetime.now())

@patch('superset.connectors.druid.models.PyDruid')
def test_client(self, PyDruid):
self.login(username='admin')
def get_cluster(self, PyDruid):
instance = PyDruid.return_value
instance.time_boundary.return_value = [
{'result': {'maxTime': '2016-01-01'}}]
@@ -110,6 +113,13 @@ def test_client(self, PyDruid):
db.session.add(cluster)
cluster.get_datasources = PickableMock(return_value=['test_datasource'])
cluster.get_druid_version = PickableMock(return_value='0.9.1')

return cluster

@patch('superset.connectors.druid.models.PyDruid')
def test_client(self, PyDruid):
self.login(username='admin')
cluster = self.get_cluster(PyDruid)
cluster.refresh_datasources()
cluster.refresh_datasources(merge_flag=True)
datasource_id = cluster.datasources[0].id
@@ -121,6 +131,7 @@ def test_client(self, PyDruid):
nres = [dict(v) for v in nres]
import pandas as pd
df = pd.DataFrame(nres)
instance = PyDruid.return_value
instance.export_pandas.return_value = df
instance.query_dict = {}
instance.query_builder.last_query.query_dict = {}
@@ -327,6 +338,64 @@ def test_sync_druid_perm(self, PyDruid):
permission=permission, view_menu=view_menu).first()
assert pv is not None

@patch('superset.connectors.druid.models.PyDruid')
def test_refresh_metadata(self, PyDruid):
self.login(username='admin')
cluster = self.get_cluster(PyDruid)
cluster.refresh_datasources()

for i, datasource in enumerate(cluster.datasources):
cols = (
db.session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
)

for col in cols:
self.assertIn(
col.column_name,
SEGMENT_METADATA[i]['columns'].keys(),
)

metrics = (
db.session.query(DruidMetric)
.filter(DruidMetric.datasource_id == datasource.id)
.filter(DruidMetric.metric_name.like('%__metric1'))
)

self.assertEqual(
{metric.metric_name for metric in metrics},
{'max__metric1', 'min__metric1', 'sum__metric1'},
)

for metric in metrics:
agg, _ = metric.metric_name.split('__')

self.assertEqual(
json.loads(metric.json)['type'],
'double{}'.format(agg.capitalize()),
)

# Augment a metric.
metadata = SEGMENT_METADATA[:]
metadata[0]['columns']['metric1']['type'] = 'LONG'
instance = PyDruid.return_value
instance.segment_metadata.return_value = metadata
cluster.refresh_datasources()

metrics = (
db.session.query(DruidMetric)
.filter(DruidMetric.datasource_id == datasource.id)
.filter(DruidMetric.metric_name.like('%__metric1'))
)

for metric in metrics:
agg, _ = metric.metric_name.split('__')

self.assertEqual(
metric.json_obj['type'],
'long{}'.format(agg.capitalize()),
)

def test_urls(self):
cluster = self.get_test_cluster_obj()
self.assertEquals(
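Note on the test changes: the new get_cluster helper and test_refresh_metadata rely on the same mock.patch pattern, replacing the PyDruid class for the duration of the test and feeding broker responses through the mock's return_value. A stripped-down sketch of that pattern with stand-in names (not Superset code):

from unittest.mock import patch

class DruidClient:  # stand-in for the real pydruid client class
    def segment_metadata(self, **kwargs):
        raise RuntimeError('would call a real broker')

def latest_columns():
    # Application code that instantiates the client itself.
    return DruidClient().segment_metadata(datasource='test')[-1]['columns']

with patch('__main__.DruidClient') as client_cls:
    instance = client_cls.return_value             # what DruidClient(...) now returns
    instance.segment_metadata.return_value = [{'columns': {'dim1': {'type': 'STRING'}}}]
    assert latest_columns() == {'dim1': {'type': 'STRING'}}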
