From ebcf3af41d25aad94af079c27c553082d2ea0914 Mon Sep 17 00:00:00 2001 From: Leonard Besson Date: Sat, 23 Dec 2023 16:49:45 +0100 Subject: [PATCH] Allow executemany to return rows --- asyncpg/connection.py | 20 ++++++++++--- asyncpg/pool.py | 6 ++-- asyncpg/prepared_stmt.py | 16 ++++++++-- asyncpg/protocol/coreproto.pxd | 2 +- asyncpg/protocol/coreproto.pyx | 6 ++-- asyncpg/protocol/protocol.pyx | 4 ++- tests/test_execute.py | 53 +++++++++++++++++++++++++++++++++- 7 files changed, 92 insertions(+), 15 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0367e365..fde63c36 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -359,7 +359,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: ) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany(self, command: str, args, *, timeout: float=None, + return_rows: bool=False): """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -373,7 +374,13 @@ async def executemany(self, command: str, args, *, timeout: float=None): :param command: Command to execute. :param args: An iterable containing sequences of arguments. :param float timeout: Optional timeout value in seconds. - :return None: This method discards the results of the operations. + :param bool return_rows: + If ``True``, the resulting rows of each command will be + returned as a list of :class:`~asyncpg.Record` + (defaults to ``False``). + :return: + None, or a list of :class:`~asyncpg.Record` instances + if `return_rows` is true. .. versionadded:: 0.7.0 @@ -386,9 +393,13 @@ async def executemany(self, command: str, args, *, timeout: float=None): to prior versions, where the effect of already-processed iterations would remain in place when an error has occurred, unless ``executemany()`` was called in a transaction. + + .. versionchanged:: 0.30.0 + Added `return_rows` keyword-only parameter. """ self._check_open() - return await self._executemany(command, args, timeout) + return await self._executemany( + command, args, timeout, return_rows=return_rows) async def _get_statement( self, @@ -1898,12 +1909,13 @@ async def __execute( ) return result, stmt - async def _executemany(self, query, args, timeout): + async def _executemany(self, query, args, timeout, return_rows): executor = lambda stmt, timeout: self._protocol.bind_execute_many( state=stmt, args=args, portal_name='', timeout=timeout, + return_rows=return_rows, ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 06e698df..acb82ffa 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -538,7 +538,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: async with self.acquire() as con: return await con.execute(query, *args, timeout=timeout) - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany(self, command: str, args, *, timeout: float=None, + return_rows: bool=False): """Execute an SQL *command* for each sequence of arguments in *args*. Pool performs this operation using one of its connections. Other than @@ -549,7 +550,8 @@ async def executemany(self, command: str, args, *, timeout: float=None): .. versionadded:: 0.10.0 """ async with self.acquire() as con: - return await con.executemany(command, args, timeout=timeout) + return await con.executemany( + command, args, timeout=timeout, return_rows=return_rows) async def fetch( self, diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 8e241d67..114e366d 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -211,18 +211,28 @@ async def fetchrow(self, *args, timeout=None): return data[0] @connresource.guarded - async def executemany(self, args, *, timeout: float=None): + async def executemany(self, args, *, timeout: float=None, + return_rows: bool=False): """Execute the statement for each sequence of arguments in *args*. :param args: An iterable containing sequences of arguments. :param float timeout: Optional timeout value in seconds. - :return None: This method discards the results of the operations. + :param bool return_rows: + If ``True``, the resulting rows of each command will be + returned as a list of :class:`~asyncpg.Record` + (defaults to ``False``). + :return: + None, or a list of :class:`~asyncpg.Record` instances + if `return_rows` is true. .. versionadded:: 0.22.0 + + .. versionchanged:: 0.30.0 + Added `return_rows` keyword-only parameter. """ return await self.__do_execute( lambda protocol: protocol.bind_execute_many( - self._state, args, '', timeout)) + self._state, args, '', timeout, return_rows=return_rows)) async def __do_execute(self, executor): protocol = self._connection._protocol diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 7ce4f574..b303f92f 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -174,7 +174,7 @@ cdef class CoreProtocol: cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data) + object bind_data, bint return_rows) cdef bint _bind_execute_many_more(self, bint first=*) cdef _bind_execute_many_fail(self, object error, bint first=*) cdef _bind(self, str portal_name, str stmt_name, diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 64afe934..5bd0b305 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -940,12 +940,12 @@ cdef class CoreProtocol: self._send_bind_message(portal_name, stmt_name, bind_data, limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data): + object bind_data, bint return_rows): self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) - self.result = None - self._discard_data = True + self.result = [] if return_rows else None + self._discard_data = not return_rows self._execute_iter = bind_data self._execute_portal_name = portal_name self._execute_stmt_name = stmt_name diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index b43b0e9c..20cc8d47 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -213,6 +213,7 @@ cdef class BaseProtocol(CoreProtocol): args, portal_name: str, timeout, + return_rows: bool, ): if self.cancel_waiter is not None: await self.cancel_waiter @@ -238,7 +239,8 @@ cdef class BaseProtocol(CoreProtocol): more = self._bind_execute_many( portal_name, state.name, - arg_bufs) # network op + arg_bufs, + return_rows) # network op self.last_query = state.query self.statement = state diff --git a/tests/test_execute.py b/tests/test_execute.py index 78d8c124..2eed59e4 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -139,6 +139,45 @@ async def test_executemany_basic(self): ('a', 1), ('b', 2), ('c', 3), ('d', 4) ]) + async def test_executemany_returning(self): + result = await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) RETURNING a, b + ''', [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ], return_rows=True) + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + # Empty set + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) RETURNING a, b + ''', (), return_rows=True) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + # Without "RETURNING" + result = await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [('e', 5), ('f', 6)], return_rows=True) + self.assertEqual(result, []) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6) + ]) + async def test_executemany_bad_input(self): with self.assertRaisesRegex( exceptions.DataError, @@ -288,11 +327,13 @@ async def test_executemany_client_server_failure_conflict(self): async def test_executemany_prepare(self): stmt = await self.con.prepare(''' - INSERT INTO exmany VALUES($1, $2) + INSERT INTO exmany VALUES($1, $2) RETURNING a, b ''') result = await stmt.executemany([ ('a', 1), ('b', 2), ('c', 3), ('d', 4) ]) + # While the query contains a "RETURNING" clause, by default + # `executemany` does not return anything self.assertIsNone(result) result = await self.con.fetch(''' SELECT * FROM exmany @@ -308,3 +349,13 @@ async def test_executemany_prepare(self): self.assertEqual(result, [ ('a', 1), ('b', 2), ('c', 3), ('d', 4) ]) + # Now with `return_rows=True`, we should retrieve the tuples + # from the "RETURNING" clause. + result = await stmt.executemany([('e', 5), ('f', 6)], return_rows=True) + self.assertEqual(result, [('e', 5), ('f', 6)]) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6) + ])