diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index efd37babfc7e9..6d1be23a47466 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -133,13 +133,14 @@ def _allowed_file(filename): 'table': table, 'df': df, 'name': form.name.data, - 'con': create_engine(form.con.data, echo=False), + 'con': create_engine(form.con.data.sqlalchemy_uri, echo=False), 'schema': form.schema.data, 'if_exists': form.if_exists.data, 'index': form.index.data, 'index_label': form.index_label.data, 'chunksize': 10000, } + BaseEngineSpec.df_to_db(**df_to_db_kwargs) @classmethod diff --git a/superset/forms.py b/superset/forms.py index a07790440ffde..cacb9067eb81b 100644 --- a/superset/forms.py +++ b/superset/forms.py @@ -10,14 +10,20 @@ from flask_wtf.file import FileAllowed, FileField, FileRequired from wtforms import ( BooleanField, IntegerField, SelectField, StringField) +from wtforms.ext.sqlalchemy.fields import QuerySelectField from wtforms.validators import DataRequired, NumberRange, Optional -from superset import app +from superset import app, db +from superset.models import core as models config = app.config class CsvToDatabaseForm(DynamicForm): + # pylint: disable=E0211 + def all_db_items(): + return db.session.query(models.Database) + name = StringField( _('Table Name'), description=_('Name of table to be created from csv data.'), @@ -28,12 +34,9 @@ class CsvToDatabaseForm(DynamicForm): description=_('Select a CSV file to be uploaded to a database.'), validators=[ FileRequired(), FileAllowed(['csv'], _('CSV Files Only!'))]) - - con = SelectField( - _('Database'), - description=_('database in which to add above table.'), - validators=[DataRequired()], - choices=[]) + con = QuerySelectField( + query_factory=all_db_items, + get_pk=lambda a: a.id, get_label=lambda a: a.database_name) sep = StringField( _('Delimiter'), description=_('Delimiter used by CSV file (for whitespace use \s+).'), @@ -49,7 +52,6 @@ class CsvToDatabaseForm(DynamicForm): ('fail', _('Fail')), ('replace', _('Replace')), ('append', _('Append'))], validators=[DataRequired()]) - schema = StringField( _('Schema'), description=_('Specify a schema (if database flavour supports this).'), diff --git a/superset/views/core.py b/superset/views/core.py index 0f0adea149ede..1b9739fa37c13 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -24,10 +24,11 @@ from flask_babel import gettext as __ from flask_babel import lazy_gettext as _ import pandas as pd +from six import text_type import sqlalchemy as sqla from sqlalchemy import create_engine from sqlalchemy.engine.url import make_url -from sqlalchemy.exc import OperationalError +from sqlalchemy.exc import IntegrityError, OperationalError from unidecode import unidecode from werkzeug.routing import BaseConverter from werkzeug.utils import secure_filename @@ -163,8 +164,6 @@ def apply(self, query, func): # noqa return query - - class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa datamodel = SQLAInterface(models.Database) @@ -319,49 +318,36 @@ def form_get(self, form): form.infer_datetime_format.data = True form.decimal.data = '.' form.if_exists.data = 'append' - all_datasources = ( - db.session.query( - models.Database.sqlalchemy_uri, - models.Database.database_name) - .all() - ) - form.con.choices += all_datasources def form_post(self, form): - def _upload_file(csv_file): - if csv_file and csv_file.filename: - filename = secure_filename(csv_file.filename) - csv_file.save(os.path.join(config['UPLOAD_FOLDER'], filename)) - return filename - csv_file = form.csv_file.data - _upload_file(csv_file) - table = SqlaTable(table_name=form.name.data) - database = ( - db.session.query(models.Database) - .filter_by(sqlalchemy_uri=form.data.get('con')) - .one() - ) - table.database = database - table.database_id = database.id + form.csv_file.data.filename = secure_filename(form.csv_file.data.filename) + csv_filename = form.csv_file.data.filename try: - database.db_engine_spec.create_table_from_csv(form, table) + csv_file.save(os.path.join(config['UPLOAD_FOLDER'], csv_filename)) + table = SqlaTable(table_name=form.name.data) + table.database = form.data.get('con') + table.database_id = table.database.id + table.database.db_engine_spec.create_table_from_csv(form, table) except Exception as e: - os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_file.filename)) - flash(e, 'error') - return redirect('/tablemodelview/list/') + try: + os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_filename)) + except OSError: + pass + message = u'Table name {} already exists. Please pick another'.format( + form.name.data) if isinstance(e, IntegrityError) else text_type(e) + flash( + message, + 'danger') + return redirect('/csvtodatabaseview/form') - os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_file.filename)) + os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_filename)) # Go back to welcome page / splash screen - db_name = ( - db.session.query(models.Database.database_name) - .filter_by(sqlalchemy_uri=form.data.get('con')) - .one() - ) - message = _('CSV file "{0}" uploaded to table "{1}" in ' - 'database "{2}"'.format(form.csv_file.data.filename, + db_name = table.database.database_name + message = _(u'CSV file "{0}" uploaded to table "{1}" in ' + 'database "{2}"'.format(csv_filename, form.name.data, - db_name[0])) + db_name)) flash(message, 'info') return redirect('/tablemodelview/list/') diff --git a/tests/core_tests.py b/tests/core_tests.py index 367ab68950f44..f6eb94d3f211d 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -815,20 +815,22 @@ def test_import_csv(self): test_file.write('john,1\n') test_file.write('paul,2\n') test_file.close() - main_db_uri = db.session.query( - models.Database.sqlalchemy_uri)\ - .filter_by(database_name='main').all() + main_db_uri = ( + db.session.query(models.Database) + .filter_by(database_name='main') + .all() + ) test_file = open(filename, 'rb') form_data = { 'csv_file': test_file, 'sep': ',', 'name': table_name, - 'con': main_db_uri[0][0], + 'con': main_db_uri[0].id, 'if_exists': 'append', 'index_label': 'test_label', - 'mangle_dupe_cols': False} - + 'mangle_dupe_cols': False, + } url = '/databaseview/list/' add_datasource_page = self.get_resp(url) assert 'Upload a CSV' in add_datasource_page