Skip to content

Commit

Permalink
feat(jinja): improve url parameter formatting (apache#16711)
Browse files Browse the repository at this point in the history
* feat(jinja): improve url parameter formatting

* add UPDATING.md

* fix test
  • Loading branch information
villebro authored Sep 15, 2021
1 parent fb4650a commit 88c09c2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 4 deletions.
3 changes: 3 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ assists people when migrating to a new version.
## Next

### Breaking Changes

- [16711](https://github.com/apache/incubator-superset/pull/16711): The `url_param` Jinja function will now by default escape the result. For instance, the value `O'Brien` will now be changed to `O''Brien`. To disable this behavior, call `url_param` with `escape_result` set to `False`: `url_param("my_key", "my default", escape_result=False)`.

### Potential Downtime
### Deprecations
### Other
Expand Down
22 changes: 20 additions & 2 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import String
from typing_extensions import TypedDict

from superset.exceptions import SupersetTemplateException
Expand Down Expand Up @@ -95,9 +97,11 @@ def __init__(
self,
extra_cache_keys: Optional[List[Any]] = None,
removed_filters: Optional[List[str]] = None,
dialect: Optional[Dialect] = None,
):
self.extra_cache_keys = extra_cache_keys
self.removed_filters = removed_filters if removed_filters is not None else []
self.dialect = dialect

def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
"""
Expand Down Expand Up @@ -145,7 +149,11 @@ def cache_key_wrapper(self, key: Any) -> Any:
return key

def url_param(
self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True
self,
param: str,
default: Optional[str] = None,
add_to_cache_keys: bool = True,
escape_result: bool = True,
) -> Optional[str]:
"""
Read a url or post parameter and use it in your SQL Lab query.
Expand All @@ -166,6 +174,7 @@ def url_param(
:param param: the parameter to lookup
:param default: the value to return in the absence of the parameter
:param add_to_cache_keys: Whether the value should be included in the cache key
:param escape_result: Should special characters in the result be escaped
:returns: The URL parameters
"""

Expand All @@ -178,6 +187,11 @@ def url_param(
form_data, _ = get_form_data()
url_params = form_data.get("url_params") or {}
result = url_params.get(param, default)
if result and escape_result and self.dialect:
# use the dialect specific quoting logic to escape string
result = String().literal_processor(dialect=self.dialect)(value=result)[
1:-1
]
if add_to_cache_keys:
self.cache_key_wrapper(result)
return result
Expand Down Expand Up @@ -430,7 +444,11 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
class JinjaTemplateProcessor(BaseTemplateProcessor):
def set_context(self, **kwargs: Any) -> None:
super().set_context(**kwargs)
extra_cache = ExtraCache(self._extra_cache_keys, self._removed_filters)
extra_cache = ExtraCache(
extra_cache_keys=self._extra_cache_keys,
removed_filters=self._removed_filters,
dialect=self._database.get_dialect(),
)
self._context.update(
{
"url_param": partial(safe_proxy, extra_cache.url_param),
Expand Down
13 changes: 11 additions & 2 deletions tests/integration_tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import dialect

from tests.integration_tests.test_app import app
from superset.sql_parse import CtasMethod
Expand Down Expand Up @@ -422,15 +424,22 @@ def create_fake_db_for_macros(self):
self.login(username="admin")
database_name = "db_for_macros_testing"
db_id = 200
return self.get_or_create(
database = self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
sqlalchemy_uri="db_for_macros_testing://user@host:8080/hive",
id=db_id,
)

def delete_fake_db_for_macros(self):
def mock_get_dialect() -> Dialect:
return dialect()

database.get_dialect = mock_get_dialect
return database

@staticmethod
def delete_fake_db_for_macros():
database = (
db.session.query(Database)
.filter(Database.database_name == "db_for_macros_testing")
Expand Down
31 changes: 31 additions & 0 deletions tests/integration_tests/jinja_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest import mock

import pytest
from sqlalchemy.dialects.postgresql import dialect

import tests.integration_tests.test_app
from superset import app
Expand Down Expand Up @@ -199,6 +200,36 @@ def test_url_param_form_data(self) -> None:
cache = ExtraCache()
self.assertEqual(cache.url_param("foo"), "bar")

def test_url_param_escaped_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(cache.url_param("foo"), "O''Brien")

def test_url_param_escaped_default_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(cache.url_param("bar", "O'Malley"), "O''Malley")

def test_url_param_unescaped_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(cache.url_param("foo", escape_result=False), "O'Brien")

def test_url_param_unescaped_default_form_data(self) -> None:
with app.test_request_context(
query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
):
cache = ExtraCache(dialect=dialect())
self.assertEqual(
cache.url_param("bar", "O'Malley", escape_result=False), "O'Malley"
)

def test_safe_proxy_primitive(self) -> None:
def func(input: Any) -> Any:
return input
Expand Down

0 comments on commit 88c09c2

Please sign in to comment.