Skip to content

Commit

Permalink
make NominatimAPI[Async] a context manager
Browse files Browse the repository at this point in the history
If close() isn't properly called, it can lead to odd error messages
about uncaught exceptions.
  • Loading branch information
lonvia committed Aug 19, 2024
1 parent 8b41b80 commit c2594ac
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 65 deletions.
25 changes: 23 additions & 2 deletions src/nominatim_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
This class shares most of the functions with its synchronous
version. There are some additional functions or parameters,
which are documented below.
This class should usually be used as a context manager in 'with' context.
"""
def __init__(self, project_dir: Path,
environ: Optional[Mapping[str, str]] = None,
Expand Down Expand Up @@ -166,6 +168,14 @@ async def close(self) -> None:
await self._engine.dispose()


async def __aenter__(self) -> 'NominatimAPIAsync':
return self


async def __aexit__(self, *_: Any) -> None:
await self.close()


@contextlib.asynccontextmanager
async def begin(self) -> AsyncIterator[SearchConnection]:
""" Create a new connection with automatic transaction handling.
Expand Down Expand Up @@ -351,6 +361,8 @@ class NominatimAPI:
""" This class provides a thin synchronous wrapper around the asynchronous
Nominatim functions. It creates its own event loop and runs each
synchronous function call to completion using that loop.
This class should usually be used as a context manager in 'with' context.
"""

def __init__(self, project_dir: Path,
Expand All @@ -376,8 +388,17 @@ def close(self) -> None:
This function also closes the asynchronous worker loop making
the NominatimAPI object unusable.
"""
self._loop.run_until_complete(self._async_api.close())
self._loop.close()
if not self._loop.is_closed():
self._loop.run_until_complete(self._async_api.close())
self._loop.close()


def __enter__(self) -> 'NominatimAPI':
return self


def __exit__(self, *_: Any) -> None:
self.close()


@property
Expand Down
7 changes: 7 additions & 0 deletions test/python/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
from pathlib import Path
import pytest
import pytest_asyncio
import time
import datetime as dt

Expand Down Expand Up @@ -244,3 +245,9 @@ def mkapi(apiobj, options=None):

for api in testapis:
api.close()


@pytest_asyncio.fixture
async def api(temp_db):
async with napi.NominatimAPIAsync(Path('/invalid')) as api:
yield api
7 changes: 3 additions & 4 deletions test/python/api/search/test_icu_query_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ async def conn(table_factory):
table_factory('word',
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')

api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn


@pytest.mark.asyncio
Expand Down
7 changes: 3 additions & 4 deletions test/python/api/search/test_legacy_query_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ class TEXT, type TEXT, country_code TEXT,
temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")

api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn


@pytest.mark.asyncio
Expand Down
14 changes: 3 additions & 11 deletions test/python/api/search/test_query_analyzer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,39 @@

import pytest

from nominatim_api import NominatimAPIAsync
from nominatim_api.search.query_analyzer_factory import make_query_analyzer
from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer

@pytest.mark.asyncio
async def test_import_icu_tokenizer(table_factory):
async def test_import_icu_tokenizer(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'icu'),
('tokenizer_import_normalisation', ':: lower();'),
('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))

api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
ana = await make_query_analyzer(conn)

assert isinstance(ana, ICUQueryAnalyzer)
await api.close()


@pytest.mark.asyncio
async def test_import_missing_property(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_property(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')

async with api.begin() as conn:
with pytest.raises(ValueError, match='Property.*not found'):
await make_query_analyzer(conn)
await api.close()


@pytest.mark.asyncio
async def test_import_missing_module(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_module(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'missing'),))

async with api.begin() as conn:
with pytest.raises(RuntimeError, match='Tokenizer not found'):
await make_query_analyzer(conn)
await api.close()

39 changes: 14 additions & 25 deletions test/python/api/test_api_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,34 @@
"""
from pathlib import Path
import pytest
import pytest_asyncio

import sqlalchemy as sa

from nominatim_api import NominatimAPIAsync

@pytest_asyncio.fixture
async def apiobj(temp_db):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
api = NominatimAPIAsync(Path('/invalid'), {})
yield api
await api.close()


@pytest.mark.asyncio
async def test_run_scalar(apiobj, table_factory):
async def test_run_scalar(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))

async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'


@pytest.mark.asyncio
async def test_run_execute(apiobj, table_factory):
async def test_run_execute(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))

async with apiobj.begin() as conn:
async with api.begin() as conn:
result = await conn.execute(sa.text('SELECT * FROM foo'))
assert result.fetchone()[0] == 'a'


@pytest.mark.asyncio
async def test_get_property_existing_cached(apiobj, table_factory):
async def test_get_property_existing_cached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))

async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'

await conn.execute(sa.text('TRUNCATE nominatim_properties'))
Expand All @@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):


@pytest.mark.asyncio
async def test_get_property_existing_uncached(apiobj, table_factory):
async def test_get_property_existing_uncached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))

async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'

await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
Expand All @@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):

@pytest.mark.asyncio
@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
async def test_get_property_missing(apiobj, table_factory, param):
async def test_get_property_missing(api, table_factory, param):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')

async with apiobj.begin() as conn:
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_property(param)


@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
assert await conn.get_db_property('server_version') > 0


@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_db_property('dfkgjd.rijg')
10 changes: 0 additions & 10 deletions test/python/api/test_api_deletable_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,10 @@
from pathlib import Path

import pytest
import pytest_asyncio

from fake_adaptor import FakeAdaptor, FakeError, FakeResponse

import nominatim_api.v1.server_glue as glue
import nominatim_api as napi

@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()


class TestDeletableEndPoint:

Expand Down Expand Up @@ -61,4 +52,3 @@ async def test_deletable(self, api):
{'place_id': 3, 'country_code': 'cd', 'name': None,
'osm_id': 781, 'osm_type': 'R',
'class': 'landcover', 'type': 'grass'}]

9 changes: 0 additions & 9 deletions test/python/api/test_api_polygons_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,10 @@
from pathlib import Path

import pytest
import pytest_asyncio

from fake_adaptor import FakeAdaptor, FakeError, FakeResponse

import nominatim_api.v1.server_glue as glue
import nominatim_api as napi

@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()


class TestPolygonsEndPoint:

Expand Down

0 comments on commit c2594ac

Please sign in to comment.