Skip to content

Commit

Permalink
Now the CORS header accepts every origin. Also updated unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
jedef committed Jun 29, 2022
1 parent 6911d28 commit 32ab3f6
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 72 deletions.
33 changes: 26 additions & 7 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from flask import Flask
from flask import g
from flask import request
from itertools import chain

from app import settings
from app.helpers import init_logging
from app.helpers.raster.georaster import GeoRasterUtils
from app.helpers.url import ALLOWED_DOMAINS_PATTERN
from app.middleware import ReverseProxy

# Initialize Logging using JSON format for all loggers and using the Stream Handler.
Expand Down Expand Up @@ -38,15 +38,34 @@ def log_route():
# Add CORS Headers to all request
@app.after_request
def add_cors_header(response):
if (
'Origin' in request.headers and
re.match(ALLOWED_DOMAINS_PATTERN, request.headers['Origin'])
):
response.headers['Access-Control-Allow-Origin'] = request.headers['Origin']
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
response.headers['Access-Control-Allow-Origin'] = "*"
response.headers['Access-Control-Allow-Headers'] = "*"
response.headers.set(
'Access-Control-Allow-Methods', ', '.join(get_registered_method(app, request.url_rule))
)
return response


# Helper method for add_cors_header
def get_registered_method(app, url_rule):
'''Returns the list of registered method for the given endpoint'''

# The list of registered method is taken from the werkzeug.routing.Rule. A Rule object
# has a methods property with the list of allowed method on an endpoint. If this property is
# missing then all methods are allowed.
# See https://werkzeug.palletsprojects.com/en/2.0.x/routing/#werkzeug.routing.Rule
all_methods = ['GET', 'HEAD', 'OPTIONS', 'POST', 'PUT', 'DELETE']
return set(
chain.from_iterable(
[
r.methods if r.methods else all_methods
for r in app.url_map.iter_rules()
if r.rule == str(url_rule)
]
)
)


# NOTE it is better to have this method registered last (after add_cors_header) otherwise
# the response might not be correct (e.g. headers added in another after_request hook).
@app.after_request
Expand Down
7 changes: 0 additions & 7 deletions app/helpers/url.py

This file was deleted.

5 changes: 4 additions & 1 deletion tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

ENDPOINT_FOR_JSON_PROFILE = '/rest/services/profile.json'
ENDPOINT_FOR_CSV_PROFILE = '/rest/services/profile.csv'
DEFAULT_HEADERS = {'Origin': 'https://map.geo.admin.ch'}
DEFAULT_INTERN_HEADERS = {'Origin': 'https://map.geo.admin.ch'}
DEFAULT_EXTERN_HEADERS = {'Origin': 'https://extern-company.com'}
# Should accept anyone, as it is a public api
DEFAULT_HEADERS = DEFAULT_EXTERN_HEADERS

POINT_1_LV03 = [630000, 170000]
POINT_2_LV03 = [634000, 173000]
Expand Down
60 changes: 60 additions & 0 deletions tests/unit_tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest

from mock import patch

with patch('os.path.exists') as mock_exists:
mock_exists.return_value = True
import app as service_alti

from flask.helpers import url_for
from tests.unit_tests import DEFAULT_HEADERS
from tests.unit_tests import DEFAULT_INTERN_HEADERS
from tests.unit_tests import DEFAULT_EXTERN_HEADERS


class BaseRouteTestCase(unittest.TestCase):

def setUp(self) -> None:
self.test_instance = service_alti.app.test_client()
self.context = service_alti.app.test_request_context()
self.context.push()
service_alti.app.config['TESTING'] = True
self.headers = DEFAULT_HEADERS

def check_response(self, response, expected_status=200, expected_allowed_methods=None):
if expected_allowed_methods is None:
expected_allowed_methods = ['GET', 'HEAD', 'OPTIONS']
self.assertIsNotNone(response)
self.assertEqual(response.status_code, expected_status, msg=response.get_data(as_text=True))
self.assertCors(response, expected_allowed_methods)

def assertCors(self, response, expected_allowed_methods): # pylint: disable=invalid-name
self.assertIn('Access-Control-Allow-Origin', response.headers)
self.assertEqual(response.headers['Access-Control-Allow-Origin'], '*')
self.assertIn('Access-Control-Allow-Methods', response.headers)
self.assertListEqual(
sorted(expected_allowed_methods),
sorted(
map(
lambda m: m.strip(),
response.headers['Access-Control-Allow-Methods'].split(',')
)
)
)
self.assertIn('Access-Control-Allow-Headers', response.headers)
self.assertEqual(response.headers['Access-Control-Allow-Headers'], '*')


class CheckerTests(BaseRouteTestCase):

def test_checker_intern_origin(self):
response = self.test_instance.get(url_for('check'), headers=DEFAULT_INTERN_HEADERS)
self.check_response(response)
self.assertNotIn('Cache-Control', response.headers)
self.assertEqual(response.content_type, "application/json")

def test_checker_extern_origin(self):
response = self.test_instance.get(url_for('check'), headers=DEFAULT_EXTERN_HEADERS)
self.check_response(response)
self.assertNotIn('Cache-Control', response.headers)
self.assertEqual(response.content_type, "application/json")
17 changes: 3 additions & 14 deletions tests/unit_tests/test_height.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
# -*- coding: utf-8 -*-
import unittest

from mock import Mock
from mock import patch

with patch('os.path.exists') as mock_exists:
mock_exists.return_value = True
import app as service_alti

from tests.unit_tests import DEFAULT_HEADERS
from tests.unit_tests.test_base import BaseRouteTestCase

EAST_LV03, NORTH_LV03 = 632510.0, 170755.0
# LV95
Expand All @@ -17,14 +12,9 @@
HEIGHT_DTM2, HEIGHT_DTM25 = 568.2, 567.6


class TestHeight(unittest.TestCase):
class TestHeight(BaseRouteTestCase):
# pylint: disable=too-many-public-methods

def setUp(self) -> None:
service_alti.app.config['TESTING'] = True
self.test_instance = service_alti.app.test_client()
self.headers = DEFAULT_HEADERS

def __test_get(self, params):
return self.test_instance.get(
'/rest/services/height', query_string=params, headers=self.headers
Expand All @@ -39,8 +29,7 @@ def __prepare_mock_and_test_get(
raster_mock.get_height_for_coordinate.return_value = return_value
mock_georaster_utils.get_raster.return_value = raster_mock
response = self.__test_get(params)
self.assertIsNotNone(response)
self.assertEqual(response.status_code, expected_status, msg=response.data)
self.check_response(response, expected_status)
return response

def __assert_height(self, response, expected_height):
Expand Down
21 changes: 6 additions & 15 deletions tests/unit_tests/test_profile.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
# -*- coding: utf-8 -*-
import logging
import unittest

from mock import patch

with patch('os.path.exists') as mock_exists:
mock_exists.return_value = True
import app as service_alti

from app.helpers.profile_helpers import PROFILE_DEFAULT_AMOUNT_POINTS
from app.helpers.profile_helpers import PROFILE_MAX_AMOUNT_POINTS
from tests import create_json
from tests.unit_tests import DEFAULT_HEADERS
from tests.unit_tests import ENDPOINT_FOR_CSV_PROFILE
from tests.unit_tests import ENDPOINT_FOR_JSON_PROFILE
from tests.unit_tests import LINESTRING_MISSPELLED_SHAPE
Expand All @@ -24,20 +18,17 @@
from tests.unit_tests import POINT_2_LV03
from tests.unit_tests import POINT_3_LV03
from tests.unit_tests import prepare_mock
from tests.unit_tests.test_base import BaseRouteTestCase

logger = logging.getLogger(__name__)


class TestProfileBase(unittest.TestCase):

def setUp(self) -> None:
service_alti.app.config['TESTING'] = True
self.test_instance = service_alti.app.test_client()
self.headers = DEFAULT_HEADERS
class TestProfileBase(BaseRouteTestCase):

def check_response(self, response, expected_status=200):
self.assertIsNotNone(response)
self.assertEqual(response.status_code, expected_status, msg=response.get_data(as_text=True))
def check_response(self, response, expected_status=200, expected_allowed_methods=None):
if expected_allowed_methods is None:
expected_allowed_methods = ['GET', 'HEAD', 'POST', 'OPTIONS']
super().check_response(response, expected_status, expected_allowed_methods)

def assert_response_contains(self, response, content):
self.assertTrue(
Expand Down
45 changes: 17 additions & 28 deletions tests/unit_tests/test_profile_validation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
# -*- coding: utf-8 -*-
import logging
import unittest

from mock import patch

with patch('os.path.exists') as mock_exists:
mock_exists.return_value = True
import app as service_alti

from app.helpers.profile_helpers import PROFILE_DEFAULT_AMOUNT_POINTS
from app.helpers.profile_helpers import PROFILE_MAX_AMOUNT_POINTS
from tests import create_json
from tests.unit_tests import DEFAULT_HEADERS
from tests.unit_tests import prepare_mock

from tests.unit_tests.test_profile import TestProfileBase

logger = logging.getLogger(__name__)

INVALID_LINESTRING_NOT_GEOJSON = "hello world"
Expand All @@ -24,15 +21,7 @@
INVALID_OFFSET = "hello world"


class TestProfileValidation(unittest.TestCase):

def setUp(self) -> None:
service_alti.app.config['TESTING'] = True
self.test_instance = service_alti.app.test_client()

def assert_response(self, response, expected_status=200):
self.assertIsNotNone(response)
self.assertEqual(response.status_code, expected_status, msg=response.data)
class TestProfileValidation(TestProfileBase):

def prepare_mock_and_test(
self, linestring, spatial_reference, nb_points, offset, mock_georaster_utils
Expand All @@ -59,7 +48,7 @@ def test_profile_validation_valid(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response)
self.check_response(response)
profile = response.get_json()
self.assertEqual(VALID_NB_POINTS, len(profile))

Expand All @@ -72,7 +61,7 @@ def test_profile_validation_valid_nb_points_none(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response)
self.check_response(response)
profile = response.get_json()
self.assertEqual(PROFILE_DEFAULT_AMOUNT_POINTS, len(profile))

Expand All @@ -85,7 +74,7 @@ def test_profile_validation_valid_offset_none(self, mock_georaster_utils):
offset=None,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response)
self.check_response(response)

@patch('app.routes.georaster_utils')
def test_profile_validation_wrong_content_type(self, mock_georaster_utils):
Expand All @@ -96,7 +85,7 @@ def test_profile_validation_wrong_content_type(self, mock_georaster_utils):
**DEFAULT_HEADERS, 'Content-Type': 'text/plain'
}
)
self.assert_response(response, expected_status=415)
self.check_response(response, expected_status=415)

@patch('app.routes.georaster_utils')
def test_profile_validation_no_linestring(self, mock_georaster_utils):
Expand All @@ -107,7 +96,7 @@ def test_profile_validation_no_linestring(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)

@patch('app.routes.georaster_utils')
def test_profile_validation_not_a_geojson_linestring(self, mock_georaster_utils):
Expand All @@ -118,7 +107,7 @@ def test_profile_validation_not_a_geojson_linestring(self, mock_georaster_utils)
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)

@patch('app.routes.georaster_utils')
def test_profile_validation_linestring_too_long(self, mock_georaster_utils):
Expand All @@ -129,7 +118,7 @@ def test_profile_validation_linestring_too_long(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=413)
self.check_response(response, expected_status=413)

@patch('app.routes.georaster_utils')
def test_profile_validation_wrong_srid(self, mock_georaster_utils):
Expand All @@ -140,7 +129,7 @@ def test_profile_validation_wrong_srid(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)

@patch('app.routes.georaster_utils')
def test_profile_validation_nb_points_less_than_two(self, mock_georaster_utils):
Expand All @@ -151,7 +140,7 @@ def test_profile_validation_nb_points_less_than_two(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)

@patch('app.routes.georaster_utils')
def test_profile_validation_nb_points_too_big(self, mock_georaster_utils):
Expand All @@ -162,7 +151,7 @@ def test_profile_validation_nb_points_too_big(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)

@patch('app.routes.georaster_utils')
def test_profile_validation_invalid_nb_points(self, mock_georaster_utils):
Expand All @@ -173,7 +162,7 @@ def test_profile_validation_invalid_nb_points(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)
self.assertEqual(
response.json['error']['message'],
'Please provide a numerical value for the parameter '
Expand All @@ -187,7 +176,7 @@ def test_profile_validation_invalid_nb_points(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)
self.assertEqual(
response.json['error']['message'],
'Please provide a numerical value for the parameter '
Expand All @@ -201,7 +190,7 @@ def test_profile_validation_invalid_nb_points(self, mock_georaster_utils):
offset=VALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)
self.assertEqual(
response.json['error']['message'],
'Please provide a numerical value for the parameter '
Expand All @@ -217,4 +206,4 @@ def test_profile_validation_offset_not_int(self, mock_georaster_utils):
offset=INVALID_OFFSET,
mock_georaster_utils=mock_georaster_utils
)
self.assert_response(response, expected_status=400)
self.check_response(response, expected_status=400)

0 comments on commit 32ab3f6

Please sign in to comment.