Skip to content

Commit

Permalink
Maintain lookup tables in SQLite
Browse files Browse the repository at this point in the history
Refs #17
  • Loading branch information
simonw authored and Simon Willison committed Jan 23, 2018
1 parent 65ac5d4 commit 42ae631
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 67 deletions.
41 changes: 18 additions & 23 deletions csvs_to_sqlite/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,33 +111,28 @@ def cli(paths, dbname, separator, quoting, skip_errors, replace_tables, table, e

# Now we have loaded the dataframes, we can refactor them
created_tables = {}
refactored = refactor_dataframes(dataframes, foreign_keys)
refactored = refactor_dataframes(conn, dataframes, foreign_keys)
for df in refactored:
if isinstance(df, LookupTable):
# This is a bit trickier because we need to
# create the table with extra SQL for foreign keys
if replace_tables and table_exists(conn, df.table_name):
drop_table(conn, df.table_name)
if table_exists(conn, df.table_name):
df.to_sql(
df.table_name, conn
df.table_name,
conn,
if_exists='append',
index=False,
)
else:
# This is a bit trickier because we need to
# create the table with extra SQL for foreign keys
if replace_tables and table_exists(conn, df.table_name):
drop_table(conn, df.table_name)
if table_exists(conn, df.table_name):
df.to_sql(
df.table_name,
conn,
if_exists='append',
index=False,
)
else:
to_sql_with_foreign_keys(
conn, df, df.table_name, foreign_keys, sql_type_overrides,
index_fks=not no_index_fks
)
created_tables[df.table_name] = df
if index:
for index_defn in index:
add_index(conn, df.table_name, index_defn)
to_sql_with_foreign_keys(
conn, df, df.table_name, foreign_keys, sql_type_overrides,
index_fks=not no_index_fks
)
created_tables[df.table_name] = df
if index:
for index_defn in index:
add_index(conn, df.table_name, index_defn)

# Create FTS tables
if fts:
Expand Down
92 changes: 57 additions & 35 deletions csvs_to_sqlite/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import fnmatch
import hashlib
import lru
import pandas as pd
import numpy as np
import re
Expand Down Expand Up @@ -67,55 +68,75 @@ def add_file(filepath):


class LookupTable:
# This should probably be a pandas Series or DataFrame
def __init__(self, table_name, value_column):
def __init__(self, conn, table_name, value_column):
self.conn = conn
self.table_name = table_name
self.value_column = value_column
self.next_id = 1
self.id_to_value = {}
self.value_to_id = {}
self.cache = lru.LRUCacheDict(max_size=1000)
self.ensure_table_exists()

def ensure_table_exists(self):
if not self.conn.execute('''
SELECT name
FROM sqlite_master
WHERE type='table'
AND name=?
''', (self.table_name,)).fetchall():
create_sql = '''
CREATE TABLE "{table_name}" (
"id" INTEGER PRIMARY KEY,
"{value_column}" TEXT
);
'''.format(
table_name=self.table_name,
value_column=self.value_column,
)
self.conn.execute(create_sql)

def __repr__(self):
return '<{}: {} rows>'.format(
self.table_name, len(self.id_to_value)
self.table_name, self.conn.execute(
'select count(*) from "{}"'.format(self.table_name)
).fetchone()[0]
)

def id_for_value(self, value):
if pd.isnull(value):
return None
# value should be a string
if not isinstance(value, six.string_types):
if isinstance(value, float):
value = '{0:g}'.format(value)
else:
value = six.text_type(value)
try:
return self.value_to_id[value]
# First try our in-memory cache
return self.cache[value]
except KeyError:
id = self.next_id
self.id_to_value[id] = value
self.value_to_id[value] = id
self.next_id += 1
# Next try the database table
sql = 'SELECT id FROM "{table_name}" WHERE "{value_column}"=?'.format(
table_name=self.table_name,
value_column=self.value_column,
)
result = self.conn.execute(sql, (value,)).fetchall()
if result:
id = result[0][0]
else:
# Not in DB! Insert it
cursor = self.conn.cursor()
insert_sql = '''
INSERT INTO "{table_name}" ("{value_column}") VALUES (?);
'''.format(
table_name=self.table_name,
value_column=self.value_column,
)
cursor.execute(insert_sql, (value,))
id = cursor.lastrowid
self.cache[value] = id
return id

def to_sql(self, name, conn):
create_sql, columns = get_create_table_sql(name, pd.Series(
self.id_to_value,
name=self.value_column,
), index_label='id')
# This table does not have a primary key. Let's fix that:
before, after = create_sql.split('"id" INTEGER', 1)
create_sql = '{} "id" INTEGER PRIMARY KEY {}'.format(
before, after,
)
conn.executescript(create_sql)
# Now that we have created the table, insert the rows:
pd.Series(
self.id_to_value,
name=self.value_column,
).to_sql(
name,
conn,
if_exists='append',
index_label='id'
)


def refactor_dataframes(dataframes, foreign_keys):
def refactor_dataframes(conn, dataframes, foreign_keys):
lookup_tables = {}
for column, (table_name, value_column) in foreign_keys.items():
# Now apply this to the dataframes
Expand All @@ -124,14 +145,15 @@ def refactor_dataframes(dataframes, foreign_keys):
lookup_table = lookup_tables.get(table_name)
if lookup_table is None:
lookup_table = LookupTable(
conn=conn,
table_name=table_name,
value_column=value_column,
)
lookup_tables[table_name] = lookup_table
dataframe[column] = dataframe[column].apply(
lookup_table.id_for_value
)
return list(lookup_tables.values()) + dataframes
return dataframes


def table_exists(conn, table):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
install_requires=[
'click==6.7',
'pandas==0.20.3',
'py-lru-cache==0.1.4',
'six',
],
setup_requires=['pytest-runner'],
Expand Down
10 changes: 5 additions & 5 deletions tests/test_csvs_to_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def test_extract_columns():
('Yolo', 100001, 'President', None, 'PAF', 'Gloria Estela La Riva', 8),
('Yolo', 100001, 'Proposition 51', None, None, 'No', 398),
('Yolo', 100001, 'Proposition 51', None, None, 'Yes', 460),
('Yolo', 100001, 'State Assembly', 7, 'DEM', 'Kevin McCarty', 572),
('Yolo', 100001, 'State Assembly', 7, 'REP', 'Ryan K. Brown', 291)
('Yolo', 100001, 'State Assembly', '7', 'DEM', 'Kevin McCarty', 572),
('Yolo', 100001, 'State Assembly', '7', 'REP', 'Ryan K. Brown', 291)
] == rows
last_row = rows[-1]
for i, t in enumerate((string_types, int, string_types, int, string_types, string_types, int)):
for i, t in enumerate((string_types, int, string_types, string_types, string_types, string_types, int)):
assert isinstance(last_row[i], t)

# Check that the various foreign key tables have the right things in them
Expand All @@ -98,7 +98,7 @@ def test_extract_columns():
(3, 'State Assembly'),
] == conn.execute('select * from office').fetchall()
assert [
(1, 7),
(1, '7'),
] == conn.execute('select * from district').fetchall()
assert [
(1, 'LIB'),
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_shape_with_extract_columns():
assert result.exit_code == 0
conn = sqlite3.connect('test.db')
assert [
('Yolo', 41, 'test'),
('Yolo', '41', 'test'),
] == conn.execute('''
select
Cty.value, Vts.value, Source.value
Expand Down
13 changes: 9 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,15 @@ def test_refactor_dataframes():
'name': 'Owen',
'score': 0.7,
}])
output = utils.refactor_dataframes([df], {'name': ('People', 'first_name')})
assert 2 == len(output)
lookup_table, dataframe = output
assert {1: 'Terry', 2: 'Owen'} == lookup_table.id_to_value
conn = sqlite3.connect(':memory:')
output = utils.refactor_dataframes(conn, [df], {'name': ('People', 'first_name')})
assert 1 == len(output)
dataframe = output[0]
# There should be a 'People' table in sqlite
assert [
(1, 'Terry'),
(2, 'Owen'),
] == conn.execute('select id, first_name from People').fetchall()
assert (
' name score\n'
'0 1 0.5\n'
Expand Down

0 comments on commit 42ae631

Please sign in to comment.