Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement local caching for WMTS requests #2316

Merged
merged 17 commits into from
Jan 3, 2025
83 changes: 68 additions & 15 deletions lib/cartopy/io/ogc_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import collections
import io
import math
import os
from pathlib import Path
from urllib.parse import urlparse
import warnings
import weakref
Expand All @@ -27,6 +29,8 @@
from PIL import Image
import shapely.geometry as sgeom

import cartopy


try:
import owslib.util
Expand Down Expand Up @@ -357,7 +361,7 @@ class WMTSRasterSource(RasterSource):

"""

def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):
def __init__(self, wmts, layer_name, gettile_extra_kwargs=None, cache=False):
dnowacki-usgs marked this conversation as resolved.
Show resolved Hide resolved
"""
Parameters
----------
Expand All @@ -368,6 +372,9 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):
gettile_extra_kwargs: dict, optional
Extra keywords (e.g. time) to pass through to the
service's gettile method.
cache : bool or str, optional
If True, the default cache directory is used. If False, no cache is
used. If a string, the string is used as the path to the cache.

"""
if WebMapService is None:
Expand Down Expand Up @@ -397,6 +404,18 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):

self._matrix_set_name_map = {}

# Enable a cache mechanism when cache is equal to True or to a path.
self._default_cache = False
if cache is True:
self._default_cache = True
self.cache_path = Path(cartopy.config["cache_dir"])
elif cache is False:
self.cache_path = None
else:
self.cache_path = Path(cache)
self.cache = set({})
self._load_cache()

def _matrix_set_name(self, target_projection):
key = id(target_projection)
matrix_set_name = self._matrix_set_name_map.get(key)
Expand Down Expand Up @@ -510,6 +529,23 @@ def fetch_raster(self, projection, extent, target_resolution):

return located_images

@property
def _cache_dir(self):
    """Directory in which this source stores its cached tiles.

    The location is ``self.cache_path`` extended with a sub-directory
    named after the instance's class.
    """
    subdir_name = self.__class__.__name__
    return self.cache_path / subdir_name

def _load_cache(self):
    """Register the tiles already present in the on-disk cache.

    Does nothing when ``self.cache_path`` is None (caching disabled).
    Otherwise the cache directory is created if it is missing and every
    entry found in it is added to ``self.cache``.

    A warning is emitted only when the directory had to be created *and*
    the default cache location is in use (``self._default_cache``).

    Entries are added as the full ``Path`` objects yielded by
    ``iterdir()``, which is the form the ``cached_file in self.cache``
    lookup in the tile-fetching code compares against.
    NOTE(review): that same code registers freshly saved tiles with
    ``self.cache.add(filename)`` (a bare file name string), which does not
    match these entries -- confirm and align.
    """
    if self.cache_path is None:
        return

    cache_dir = self._cache_dir
    try:
        # Create-then-handle instead of exists()-then-create: another
        # process or source may create the directory between a check and
        # the creation, which would otherwise raise FileExistsError.
        cache_dir.mkdir(parents=True)
    except FileExistsError:
        # Directory is already there: nothing to create, nothing to
        # announce.
        pass
    else:
        if self._default_cache:
            warnings.warn(
                'Cartopy created the following directory to cache '
                f'WMTSRasterSource tiles: {cache_dir}')

    self.cache.update(cache_dir.iterdir())

def _choose_matrix(self, tile_matrices, meters_per_unit, max_pixel_span):
# Get the tile matrices in order of increasing resolution.
tile_matrices = sorted(tile_matrices,
Expand Down Expand Up @@ -642,21 +678,38 @@ def _wmts_images(self, wmts, layer, matrix_set_name, extent,
# Get the tile's Image from the cache if possible.
img_key = (row, col)
img = image_cache.get(img_key)

if img is None:
try:
tile = wmts.gettile(
layer=layer.id,
tilematrixset=matrix_set_name,
tilematrix=str(tile_matrix_id),
row=str(row), column=str(col),
**self.gettile_extra_kwargs)
except owslib.util.ServiceException as exception:
if ('TileOutOfRange' in exception.message and
ignore_out_of_range):
continue
raise exception
img = Image.open(io.BytesIO(tile.read()))
image_cache[img_key] = img
# Try it from disk cache
if self.cache_path is not None:
filename = f"{img_key[0]}_{img_key[1]}.npy"
cached_file = self._cache_dir / filename
else:
filename = None
cached_file = None

if cached_file in self.cache:
img = Image.fromarray(np.load(cached_file, allow_pickle=False))
else:
try:
tile = wmts.gettile(
layer=layer.id,
tilematrixset=matrix_set_name,
tilematrix=str(tile_matrix_id),
row=str(row), column=str(col),
**self.gettile_extra_kwargs)
except owslib.util.ServiceException as exception:
if ('TileOutOfRange' in exception.message and
ignore_out_of_range):
continue
raise exception
img = Image.open(io.BytesIO(tile.read()))
image_cache[img_key] = img
# save image to local cache
if self.cache_path is not None:
np.save(cached_file, img, allow_pickle=False)
self.cache.add(filename)

if big_img is None:
size = (img.size[0] * n_cols, img.size[1] * n_rows)
big_img = Image.new('RGBA', size, (255, 255, 255, 255))
Expand Down
4 changes: 2 additions & 2 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,7 +2224,7 @@ def streamplot(self, x, y, u, v, **kwargs):
sp = super().streamplot(x, y, u, v, **kwargs)
return sp

def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
def add_wmts(self, wmts, layer_name, wmts_kwargs=None, cache=False, **kwargs):
"""
Add the specified WMTS layer to the axes.

Expand All @@ -2249,7 +2249,7 @@ def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
"""
from cartopy.io.ogc_clients import WMTSRasterSource
wmts = WMTSRasterSource(wmts, layer_name,
gettile_extra_kwargs=wmts_kwargs)
gettile_extra_kwargs=wmts_kwargs, cache=cache)
return self.add_raster(wmts, **kwargs)

def add_wms(self, wms, layers, wms_kwargs=None, **kwargs):
Expand Down
78 changes: 78 additions & 0 deletions lib/cartopy/tests/test_img_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from cartopy import config
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
import cartopy.io.ogc_clients as ogc


RESOLUTION = (30, 30)

#: Maps Google tile coordinates to native mercator coordinates as defined
#: by https://goo.gl/pgJi.
KNOWN_EXTENTS = {(0, 0, 0): (-20037508.342789244, 20037508.342789244,
Expand Down Expand Up @@ -328,6 +331,81 @@ def test_azuremaps_get_image():
assert extent1 == extent2


@pytest.mark.network
@pytest.mark.parametrize('cache_dir', ["tmpdir", True, False])
dnowacki-usgs marked this conversation as resolved.
Show resolved Hide resolved
def test_wmts_cache(cache_dir, tmp_path, monkeypatch):
    """Check that WMTSRasterSource writes tiles to, and reads them from, disk.

    ``cache_dir`` is the parametrized mode: ``"tmpdir"`` passes an explicit
    cache path, ``True`` uses the directory configured in
    ``config["cache_dir"]`` and ``False`` disables caching.
    """
    # Translate the parametrized mode into the actual ``cache`` argument.
    cache_arg = str(tmp_path) if cache_dir == "tmpdir" else cache_dir

    if cache_dir is True:
        # Redirect the default cache location for this test only;
        # monkeypatch restores the global config entry on teardown so the
        # temporary path does not leak into other tests.
        monkeypatch.setitem(config, "cache_dir", str(tmp_path))

    # URI = 'https://map1c.vis.earthdata.nasa.gov/wmts-geo/wmts.cgi'
    # layer_name = 'VIIRS_CityLights_2012'
    URI = ('https://basemap.nationalmap.gov/arcgis/rest/services/'
           'USGSImageryOnly/MapServer/WMTS/1.0.0/WMTSCapabilities.xml')
    layer_name = 'USGSImageryOnly'
    projection = ccrs.PlateCarree()

    # Fetch tiles and save them in the cache
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        source = ogc.WMTSRasterSource(URI, layer_name, cache=cache_arg)
        extent = [-10, 10, 40, 60]
        located_image, = source.fetch_raster(projection, extent,
                                             RESOLUTION)

    # Do not check the result if the cache is disabled
    if cache_dir is False:
        assert source.cache_path is None
        return

    # The "directory created" warning must be raised only for the default
    # cache location (cache=True). Unrelated warnings from third-party
    # code are ignored so they cannot break the count.
    cache_warnings = [
        warning for warning in w
        if 'Cartopy created the following directory' in str(warning.message)
    ]
    assert len(cache_warnings) == (1 if cache_dir is True else 0)

    # Expected cache file names and the md5 digests of their contents
    expected = [
        ('1_1.npy', '0de548bd47e4579ae0500da6ceeb08e7'),
        ('1_2.npy', '4beebcd3e4408af5accb440d7b4c8933'),
    ]

    # Check the results; iterdir() already yields full paths.
    cache_dir_res = source.cache_path / "WMTSRasterSource"
    files = sorted(cache_dir_res.iterdir())
    # The tiles are plain arrays, so pickle support stays disabled.
    hashes = sorted(
        hashlib.md5(np.load(f, allow_pickle=False).data).hexdigest()
        for f in files
    )
    assert files == [cache_dir_res / name for name, _ in expected]
    assert set(files) == {cache_dir_res / c for c in source.cache}
    assert hashes == sorted(digest for _, digest in expected)

    # Update images in cache (all white)
    for f in files:
        img = np.load(f, allow_pickle=False)
        img.fill(255)
        np.save(f, img, allow_pickle=False)

    wmts_cache = ogc.WMTSRasterSource(URI, layer_name, cache=cache_arg)
    located_image_cache, = wmts_cache.fetch_raster(projection, extent,
                                                   RESOLUTION)

    # Check that the new fetch_raster() call used cached images
    assert wmts_cache.cache == {cache_dir_res / c for c in source.cache}
    assert (np.array(located_image_cache.image) == 255).all()


@pytest.mark.network
@pytest.mark.parametrize('cache_dir', ["tmpdir", True, False])
def test_cache(cache_dir, tmp_path):
Expand Down
Loading