Skip to content

Commit

Permalink
Add support for the WHERE clause in copy_to methods (#941)
Browse files Browse the repository at this point in the history
  • Loading branch information
redgoldlace authored Oct 9, 2023
1 parent 70c8bd8 commit b7ffab6
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 14 deletions.
58 changes: 51 additions & 7 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ async def copy_to_table(self, table_name, *, source,
delimiter=None, null=None, header=None,
quote=None, escape=None, force_quote=None,
force_not_null=None, force_null=None,
encoding=None):
encoding=None, where=None):
"""Copy data to the specified table.
:param str table_name:
Expand All @@ -885,6 +885,15 @@ async def copy_to_table(self, table_name, *, source,
:param str schema_name:
An optional schema name to qualify the table.
:param str where:
An optional SQL expression used to filter rows when copying.
.. note::
Usage of this parameter requires support for the
``COPY FROM ... WHERE`` syntax, introduced in
PostgreSQL version 12.
:param float timeout:
Optional timeout value in seconds.
Expand Down Expand Up @@ -912,6 +921,9 @@ async def copy_to_table(self, table_name, *, source,
https://www.postgresql.org/docs/current/static/sql-copy.html
.. versionadded:: 0.11.0
.. versionadded:: 0.29.0
Added the *where* parameter.
"""
tabname = utils._quote_ident(table_name)
if schema_name:
Expand All @@ -923,21 +935,22 @@ async def copy_to_table(self, table_name, *, source,
else:
cols = ''

cond = self._format_copy_where(where)
opts = self._format_copy_opts(
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
null=null, header=header, quote=quote, escape=escape,
force_not_null=force_not_null, force_null=force_null,
encoding=encoding
)

copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
tab=tabname, cols=cols, opts=opts)
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
tab=tabname, cols=cols, opts=opts, cond=cond)

return await self._copy_in(copy_stmt, source, timeout)

async def copy_records_to_table(self, table_name, *, records,
columns=None, schema_name=None,
timeout=None):
timeout=None, where=None):
"""Copy a list of records to the specified table using binary COPY.
:param str table_name:
Expand All @@ -954,6 +967,16 @@ async def copy_records_to_table(self, table_name, *, records,
:param str schema_name:
An optional schema name to qualify the table.
:param str where:
An optional SQL expression used to filter rows when copying.
.. note::
Usage of this parameter requires support for the
``COPY FROM ... WHERE`` syntax, introduced in
PostgreSQL version 12.
:param float timeout:
Optional timeout value in seconds.
Expand Down Expand Up @@ -998,6 +1021,9 @@ async def copy_records_to_table(self, table_name, *, records,
.. versionchanged:: 0.24.0
The ``records`` argument may be an asynchronous iterable.
.. versionadded:: 0.29.0
Added the *where* parameter.
"""
tabname = utils._quote_ident(table_name)
if schema_name:
Expand All @@ -1015,14 +1041,27 @@ async def copy_records_to_table(self, table_name, *, records,

intro_ps = await self._prepare(intro_query, use_cache=True)

cond = self._format_copy_where(where)
opts = '(FORMAT binary)'

copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
tab=tabname, cols=cols, opts=opts)
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
tab=tabname, cols=cols, opts=opts, cond=cond)

return await self._protocol.copy_in(
copy_stmt, None, None, records, intro_ps._state, timeout)

def _format_copy_where(self, where):
if where and not self._server_caps.sql_copy_from_where:
raise exceptions.UnsupportedServerFeatureError(
'the `where` parameter requires PostgreSQL 12 or later')

if where:
where_clause = 'WHERE ' + where
else:
where_clause = ''

return where_clause

def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
escape=None, force_quote=None, force_not_null=None,
Expand Down Expand Up @@ -2404,7 +2443,7 @@ class _ConnectionProxy:
ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
'sql_close_all', 'jit'])
'sql_close_all', 'sql_copy_from_where', 'jit'])
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'


Expand All @@ -2417,6 +2456,7 @@ def _detect_server_capabilities(server_version, connection_settings):
sql_reset = True
sql_close_all = False
jit = False
sql_copy_from_where = False
elif hasattr(connection_settings, 'crdb_version'):
# CockroachDB detected.
advisory_locks = False
Expand All @@ -2425,6 +2465,7 @@ def _detect_server_capabilities(server_version, connection_settings):
sql_reset = False
sql_close_all = False
jit = False
sql_copy_from_where = False
elif hasattr(connection_settings, 'crate_version'):
# CrateDB detected.
advisory_locks = False
Expand All @@ -2433,6 +2474,7 @@ def _detect_server_capabilities(server_version, connection_settings):
sql_reset = False
sql_close_all = False
jit = False
sql_copy_from_where = False
else:
# Standard PostgreSQL server assumed.
advisory_locks = True
Expand All @@ -2441,13 +2483,15 @@ def _detect_server_capabilities(server_version, connection_settings):
sql_reset = True
sql_close_all = True
jit = server_version >= (11, 0)
sql_copy_from_where = server_version.major >= 12

return ServerCapabilities(
advisory_locks=advisory_locks,
notifications=notifications,
plpgsql=plpgsql,
sql_reset=sql_reset,
sql_close_all=sql_close_all,
sql_copy_from_where=sql_copy_from_where,
jit=jit,
)

Expand Down
7 changes: 6 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'ClientConfigurationError',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
'ClientConfigurationError')
'UnsupportedServerFeatureError')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -233,6 +234,10 @@ class UnsupportedClientFeatureError(InterfaceError):
"""Requested feature is unsupported by asyncpg."""


class UnsupportedServerFeatureError(InterfaceError):
"""Requested feature is unsupported by PostgreSQL server."""


class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""

Expand Down
12 changes: 8 additions & 4 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,8 @@ async def copy_to_table(
force_quote=None,
force_not_null=None,
force_null=None,
encoding=None
encoding=None,
where=None
):
"""Copy data to the specified table.
Expand Down Expand Up @@ -740,7 +741,8 @@ async def copy_to_table(
force_quote=force_quote,
force_not_null=force_not_null,
force_null=force_null,
encoding=encoding
encoding=encoding,
where=where
)

async def copy_records_to_table(
Expand All @@ -750,7 +752,8 @@ async def copy_records_to_table(
records,
columns=None,
schema_name=None,
timeout=None
timeout=None,
where=None
):
"""Copy a list of records to the specified table using binary COPY.
Expand All @@ -767,7 +770,8 @@ async def copy_records_to_table(
records=records,
columns=columns,
schema_name=schema_name,
timeout=timeout
timeout=timeout,
where=where
)

def acquire(self, *, timeout=None):
Expand Down
35 changes: 33 additions & 2 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io
import os
import tempfile
import unittest

import asyncpg
from asyncpg import _testbase as tb
Expand Down Expand Up @@ -414,7 +415,7 @@ async def test_copy_to_table_basics(self):
'*a4*|b4',
'*a5*|b5',
'*!**|*n-u-l-l*',
'n-u-l-l|bb'
'n-u-l-l|bb',
]).encode('utf-8')
)
f.seek(0)
Expand Down Expand Up @@ -644,6 +645,35 @@ async def test_copy_records_to_table_1(self):
finally:
await self.con.execute('DROP TABLE copytab')

async def test_copy_records_to_table_where(self):
if not self.con._server_caps.sql_copy_from_where:
raise unittest.SkipTest(
'COPY WHERE not supported on server')

await self.con.execute('''
CREATE TABLE copytab_where(a text, b int, c timestamptz);
''')

try:
date = datetime.datetime.now(tz=datetime.timezone.utc)
delta = datetime.timedelta(days=1)

records = [
('a-{}'.format(i), i, date + delta)
for i in range(100)
]

records.append(('a-100', None, None))
records.append(('b-999', None, None))

res = await self.con.copy_records_to_table(
'copytab_where', records=records, where='a <> \'b-999\'')

self.assertEqual(res, 'COPY 101')

finally:
await self.con.execute('DROP TABLE copytab_where')

async def test_copy_records_to_table_async(self):
await self.con.execute('''
CREATE TABLE copytab_async(a text, b int, c timestamptz);
Expand All @@ -660,7 +690,8 @@ async def record_generator():
yield ('a-100', None, None)

res = await self.con.copy_records_to_table(
'copytab_async', records=record_generator())
'copytab_async', records=record_generator(),
)

self.assertEqual(res, 'COPY 101')

Expand Down

0 comments on commit b7ffab6

Please sign in to comment.