Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

database: add filter to insert logic #623

Merged
merged 2 commits into from
Feb 19, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions server/php/cherrypy/src/webapp/db_pgv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,39 +559,53 @@ def _get_nextval(self, seq_name):
def _select_insert(self, table, table_id, stmt_fields, stmt_values):
found_id = -1

cursor = self.get_cursor()

cursor.execute("SELECT * FROM %s LIMIT 0" % table)
all_fields_for_table = [d[0] for d in cursor.description]

#
# Build the SELECT and INSERT statements
#
select_stmt = "\nSELECT %s FROM %s \n" % (table_id, table)
insert_stmt = "\nINSERT INTO %s \n (%s" % (table, table_id)
insert_stmt_values = []

#
# we filter because the client may have sent us stuff our database has
# no clue about
#
count = 0
for field in stmt_fields:
insert_stmt = insert_stmt + ", " + field

if count == 0:
select_stmt = select_stmt + " WHERE "
else:
select_stmt = select_stmt + " AND "
select_stmt = select_stmt + field + " = %s"
if field in all_fields_for_table:
insert_stmt = insert_stmt + ", " + field
insert_stmt_values.append(stmt_values[count])
if count == 0:
select_stmt = select_stmt + " WHERE "
else:
select_stmt = select_stmt + " AND "
select_stmt = select_stmt + field + " = %s"
count += 1

select_stmt = select_stmt + "\n ORDER BY " + table_id + " ASC LIMIT 1"

insert_stmt = insert_stmt + ") \nVALUES ("
insert_stmt = insert_stmt + " %s"
for value in stmt_values:
insert_stmt = insert_stmt + ", %s"
for field in stmt_fields:
if field in all_fields_for_table:
insert_stmt = insert_stmt + ", %s"
insert_stmt = insert_stmt + ")"

#
# Try the select to see if we need to insert
#
#self._logger.debug(select_stmt)
#self._logger.debug(insert_stmt)
#self._logger.debug(str(insert_stmt_values))

cursor = self.get_cursor()

values = tuple(stmt_values)
values = tuple(insert_stmt_values)
cursor.execute( select_stmt, values )
rows = cursor.fetchone()
if rows is not None:
Expand All @@ -607,8 +621,8 @@ def _select_insert(self, table, table_id, stmt_fields, stmt_values):
self._logger.debug( ", ".join(str(x) for x in values) )
found_id = self._get_nextval( "%s_%s_seq" % (table, table_id))

stmt_values.insert(0, found_id)
values = tuple(stmt_values)
insert_stmt_values.insert(0, found_id)
values = tuple(insert_stmt_values)
cursor.execute( insert_stmt, values )
# Make sure to commit after every INSERT
self._connection.commit()
Expand Down