diff --git a/CHANGES.rst b/CHANGES.rst
index c998a3da..f333413e 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -32,8 +32,10 @@ Fixes
- (:issue:`884`) Oauth re-used POST_LOGIN_VIEW which caused confusion. See below for the new configuration and implications.
- (:pr:`899`) Improve (and simplify) Two-Factor setup. See below for backwards compatability issues and new functionality.
- (:pr:`901`) Work with py_webauthn 2.0
-- (:pr:`xxx`) Remove undocumented and untested looking in session for possible 'next'
+- (:pr:`906`) Remove undocumented and untested looking in session for possible 'next'
redirect location.
+- (:pr:`xxx`) Improve CSRF documentation and testing. Fix bug where a CSRF failure could
+ return an HTML page even if the request was JSON.
Notes
++++++
@@ -101,6 +103,12 @@ Backwards Compatibility Concerns
This implementation is independent of Werkzeug (and relative Location headers are again the default).
The entire regex option has been removed.
Instead, any user-supplied path used as a redirect is parsed and quoted.
+- JSON error response has changed due to issue with WTForms form-level errors. When WTForms
+ introduced form-level errors they added it to the form.errors response using `None` as a key.
+ When serializing it, it would turn into "null". However, if there is more than one error
+ the default settings for JSON serialization in Flask attempt to sort the keys - which fails
+ with the `None` key. An issue has been filed with WTForms - and maybe it will be changed.
+ Flask-Security now changes any `None` key to `""`.
Version 5.3.3
-------------
diff --git a/docs/patterns.rst b/docs/patterns.rst
index 38bc3821..e0c5c521 100644
--- a/docs/patterns.rst
+++ b/docs/patterns.rst
@@ -212,6 +212,7 @@ Flask-Security strives to support various options for both its endpoints (e.g. `
and the application endpoints (protected with Flask-Security decorators such as :func:`.auth_required`).
If your application just uses forms that are derived from ``Flask-WTF::Flaskform`` - you are done.
+Note that all of Flask-Security's endpoints are form based (regardless of how the request was made).
CSRF: Single-Page-Applications and AJAX/XHR
@@ -387,7 +388,7 @@ CSRF: Pro-Tips
(or clients must use CSRF/session cookie for logging
in then once they have an authentication token, no further need for cookie).
- #) If you enable CSRFProtect(app) and you want to support non-form based JSON requests,
+ #) If you enable CSRFProtect(app) and you want to send request data as JSON,
then you must include the CSRF token in the header (e.g. X-CSRF-Token)
#) You must enable CSRFProtect(app) if you want to accept the CSRF token in the request
diff --git a/flask_security/decorators.py b/flask_security/decorators.py
index 31064802..9e5baf4c 100644
--- a/flask_security/decorators.py
+++ b/flask_security/decorators.py
@@ -5,9 +5,10 @@
Flask-Security decorators module
:copyright: (c) 2012-2019 by Matt Wright.
- :copyright: (c) 2019-2023 by J. Christopher Wagner (jwag).
+ :copyright: (c) 2019-2024 by J. Christopher Wagner (jwag).
:license: MIT, see LICENSE for more details.
"""
+from __future__ import annotations
from collections import namedtuple
import datetime
@@ -39,6 +40,9 @@
url_for_security,
)
+if t.TYPE_CHECKING: # pragma: no cover
+ from flask.typing import ResponseValue
+
# Convenient references
_csrf = LocalProxy(lambda: current_app.extensions["csrf"])
@@ -203,14 +207,14 @@ def _check_http_auth():
return False
-def handle_csrf(method: t.Optional[str]) -> None:
+def handle_csrf(method: str, json_response: bool = False) -> ResponseValue | None:
"""Invoke CSRF protection based on authentication method.
Usually this is called as part of a decorator, but if that isn't
appropriate, endpoint code can call this directly.
If CSRF protection is appropriate, this will call flask_wtf::protect() which
- will raise a ValidationError on CSRF failure.
+ will raise a CSRFError(BadRequest) on CSRF failure.
This routine does nothing if any of these are true:
@@ -220,10 +224,18 @@ def handle_csrf(method: t.Optional[str]) -> None:
#) csrfProtect already checked and accepted the token
+ This means in the default config - CSRF is done as part of form validation
+ not here. Only if the application calls CSRFProtect(app) will this method
+ do anything. Furthermore - since this is called PRIOR to form instantiation
+ if the request is JSON - it MUST send the csrf_token as a header.
+
If the passed in method is not in *SECURITY_CSRF_PROTECT_MECHANISMS* then not only
will no CSRF code be run, but a flag in the current context ``fs_ignore_csrf``
will be set so that downstream code knows to ignore any CSRF checks.
+ Returns None if all ok, returns a Response with JSON error if request
+ wanted JSON - else re-raises the CSRFError exception.
+
.. versionadded:: 3.3.0
"""
if (
@@ -231,13 +243,20 @@ def handle_csrf(method: t.Optional[str]) -> None:
or not current_app.extensions.get("csrf", None)
or g.get("csrf_valid", False)
):
- return
+ return None
if config_value("CSRF_PROTECT_MECHANISMS"):
if method in config_value("CSRF_PROTECT_MECHANISMS"):
- _csrf.protect() # type: ignore
+ try:
+ _csrf.protect() # type: ignore
+ except CSRFError as e:
+ if json_response:
+ payload = json_error_response(errors=e.description)
+ return _security._render_json(payload, 400, None, None)
+ raise
else:
set_request_attr("fs_ignore_csrf", True)
+ return None
def http_auth_required(realm: t.Any) -> DecoratedView:
@@ -255,7 +274,9 @@ def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if _check_http_auth():
- handle_csrf("basic")
+ eresponse = handle_csrf("basic", _security._want_json(request))
+ if eresponse:
+ return eresponse
set_request_attr("fs_authn_via", "basic")
return current_app.ensure_sync(fn)(*args, **kwargs)
r = _security.default_http_auth_realm if callable(realm) else realm
@@ -282,7 +303,9 @@ def auth_token_required(fn: DecoratedView) -> DecoratedView:
@wraps(fn)
def decorated(*args, **kwargs):
if _check_token():
- handle_csrf("token")
+ eresponse = handle_csrf("token", _security._want_json(request))
+ if eresponse:
+ return eresponse
set_request_attr("fs_authn_via", "token")
return current_app.ensure_sync(fn)(*args, **kwargs)
return _security._unauthn_handler(["token"])
@@ -407,7 +430,9 @@ def decorated_view(
# in a session cookie...
if not check_and_update_authn_fresh(within, grace, method):
return _security._reauthn_handler(within, grace)
- handle_csrf(method)
+ eresponse = handle_csrf(method, _security._want_json(request))
+ if eresponse:
+ return eresponse
set_request_attr("fs_authn_via", method)
return current_app.ensure_sync(fn)(*args, **kwargs)
return _security._unauthn_handler(ams, headers=h)
diff --git a/flask_security/templates/security/change_password.html b/flask_security/templates/security/change_password.html
index a34632cb..405de1e6 100644
--- a/flask_security/templates/security/change_password.html
+++ b/flask_security/templates/security/change_password.html
@@ -1,11 +1,12 @@
{% extends "security/base.html" %}
-{% from "security/_macros.html" import render_field_with_errors, render_field %}
+{% from "security/_macros.html" import render_field_with_errors, render_field, render_field_errors, render_form_errors %}
{% block content %}
{% include "security/_messages.html" %}
{{ _fsdomain('Change password') }}
{% endblock content %}
diff --git a/flask_security/templates/security/forgot_password.html b/flask_security/templates/security/forgot_password.html
index 1b07a09d..f5fe04db 100644
--- a/flask_security/templates/security/forgot_password.html
+++ b/flask_security/templates/security/forgot_password.html
@@ -1,12 +1,14 @@
{% extends "security/base.html" %}
-{% from "security/_macros.html" import render_field_with_errors, render_field %}
+{% from "security/_macros.html" import render_field_with_errors, render_field, render_field_errors, render_form_errors %}
{% block content %}
{% include "security/_messages.html" %}
{{ _fsdomain('Send password reset instructions') }}
{% include "security/_menu.html" %}
diff --git a/flask_security/templates/security/reset_password.html b/flask_security/templates/security/reset_password.html
index 1a9b593e..7cd7ef9c 100644
--- a/flask_security/templates/security/reset_password.html
+++ b/flask_security/templates/security/reset_password.html
@@ -1,13 +1,15 @@
{% extends "security/base.html" %}
-{% from "security/_macros.html" import render_field_with_errors, render_field %}
+{% from "security/_macros.html" import render_field_with_errors, render_field, render_field_errors, render_form_errors %}
{% block content %}
{% include "security/_messages.html" %}
{{ _fsdomain('Reset password') }}
{% include "security/_menu.html" %}
diff --git a/flask_security/templates/security/two_factor_setup.html b/flask_security/templates/security/two_factor_setup.html
index 5cf3d01c..8388e74d 100644
--- a/flask_security/templates/security/two_factor_setup.html
+++ b/flask_security/templates/security/two_factor_setup.html
@@ -17,7 +17,7 @@
#}
{% extends "security/base.html" %}
-{% from "security/_macros.html" import render_field_with_errors, render_field, render_field_no_label, render_field_errors %}
+{% from "security/_macros.html" import render_field_with_errors, render_field, render_field_no_label, render_field_errors, render_form_errors %}
{% block content %}
{% include "security/_messages.html" %}
@@ -25,6 +25,7 @@ {{ _fsdomain("Two-factor authentication adds an extra layer of security to y
{{ _fsdomain("In addition to your username and password, you'll need to use a code.") }}
{% else %}
diff --git a/flask_security/utils.py b/flask_security/utils.py
index e5adcc69..d7760dd9 100644
--- a/flask_security/utils.py
+++ b/flask_security/utils.py
@@ -8,6 +8,8 @@
:copyright: (c) 2019-2024 by J. Christopher Wagner (jwag).
:license: MIT, see LICENSE for more details.
"""
+from __future__ import annotations
+
import abc
import base64
from datetime import datetime, timedelta, timezone
@@ -51,9 +53,6 @@
from flask.typing import ResponseValue
from .datastore import User
-SB = t.Union[str, bytes]
-
-
localize_callback = LocalProxy(lambda: _security.i18n_domain.gettext)
FsPermNeed = partial(Need, "fsperm")
@@ -138,7 +137,7 @@ def find_csrf_field_name():
return None
-def is_user_authenticated(user: "User") -> bool:
+def is_user_authenticated(user: User) -> bool:
"""
return True is user is authenticated.
@@ -158,9 +157,9 @@ def is_user_authenticated(user: "User") -> bool:
def login_user(
- user: "User",
- remember: t.Optional[bool] = None,
- authn_via: t.Optional[t.List[str]] = None,
+ user: User,
+ remember: bool | None = None,
+ authn_via: list[str] | None = None,
) -> bool:
"""Perform the login routine.
@@ -323,7 +322,7 @@ def check_and_update_authn_fresh(
return False
-def get_hmac(password: SB) -> bytes:
+def get_hmac(password: str | bytes) -> bytes:
"""Returns a Base64 encoded HMAC+SHA512 of the password signed with
the salt specified by *SECURITY_PASSWORD_SALT*.
@@ -342,7 +341,7 @@ def get_hmac(password: SB) -> bytes:
return base64.b64encode(h.digest())
-def verify_password(password: SB, password_hash: SB) -> bool:
+def verify_password(password: str | bytes, password_hash: str | bytes) -> bool:
"""Returns ``True`` if the password matches the supplied hash.
:param password: A plaintext password to verify
@@ -358,7 +357,7 @@ def verify_password(password: SB, password_hash: SB) -> bool:
return _pwd_context.verify(password, password_hash)
-def verify_and_update_password(password: SB, user: "User") -> bool:
+def verify_and_update_password(password: str | bytes, user: User) -> bool:
"""Returns ``True`` if the password is valid for the specified user.
Additionally, the hashed password in the database is updated if the
@@ -389,7 +388,7 @@ def verify_and_update_password(password: SB, user: "User") -> bool:
return verified
-def hash_password(password: SB) -> t.Any:
+def hash_password(password: str | bytes) -> str:
"""Hash the specified plaintext password.
Unless the hash algorithm (as specified by `SECURITY_PASSWORD_HASH`) is listed in
@@ -682,7 +681,7 @@ def simplify_url(base_url: str, redirect_url: str) -> str:
return redirect_url
-def get_config(app: "Flask") -> t.Dict[str, t.Any]:
+def get_config(app: Flask) -> t.Dict[str, t.Any]:
"""Conveniently get the security configuration for the specified
application without the annoying 'SECURITY_' prefix.
@@ -856,7 +855,7 @@ def check_and_get_token_status(
return expired, invalid, data
-def get_identity_attributes(app: t.Optional["Flask"] = None) -> t.List[str]:
+def get_identity_attributes(app: t.Optional[Flask] = None) -> t.List[str]:
# Return list of keys of identity attributes
# Is it possible to not have any?
app = app or current_app
@@ -867,7 +866,7 @@ def get_identity_attributes(app: t.Optional["Flask"] = None) -> t.List[str]:
def get_identity_attribute(
- attr: str, app: t.Optional["Flask"] = None
+ attr: str, app: t.Optional[Flask] = None
) -> t.Dict[str, t.Any]:
"""Given an user_identity_attribute, return the defining dict.
A bit annoying since USER_IDENTITY_ATTRIBUTES is a list of dict
@@ -956,7 +955,7 @@ def use_double_hash(password_hash=None):
return not (single_hash is True or scheme in single_hash)
-def csrf_cookie_handler(response: "Response") -> "Response":
+def csrf_cookie_handler(response: Response) -> Response:
"""Called at end of every request.
Uses session to track state (set/clear)
@@ -1025,12 +1024,12 @@ def csrf_cookie_handler(response: "Response") -> "Response":
def base_render_json(
- form: "FlaskForm",
+ form: FlaskForm,
include_user: bool = True,
include_auth_token: bool = False,
additional: t.Optional[t.Dict[str, t.Any]] = None,
error_status_code: int = 400,
-) -> "ResponseValue":
+) -> ResponseValue:
"""
This method is called by all views that return JSON responses.
This fills in the response and then calls :meth:`.Security.render_json`
@@ -1078,7 +1077,7 @@ def base_render_json(
def simple_render_json(
additional: t.Optional[t.Dict[str, t.Any]] = None,
-) -> "ResponseValue":
+) -> ResponseValue:
payload = dict(csrf_token=csrf.generate_csrf())
if additional:
payload.update(additional)
@@ -1106,7 +1105,7 @@ def default_want_json(req):
def json_error_response(
errors: t.Optional[t.Union[str, list]] = None,
- field_errors: t.Optional[t.Dict[str, list]] = None,
+ field_errors: t.Optional[t.Dict[str | None, list]] = None,
) -> t.Dict[str, t.Any]:
"""Helper to create an error response.
@@ -1116,7 +1115,9 @@ def json_error_response(
The "field_errors" key which is exactly what is returned from WTForms - namely
a dict of field-name: msg. For form-level errors (WTForms 3.0) the 'field-name' is
- 'None'
+ None - which alas means it isn't sortable and Flask's default JSONProvider
+ sorts keys - so we change that to '__all__' which is what django uses
+ apparently and was suggested as part of WTForms 3.0.
"""
response_json: t.Dict[str, t.Union[list, t.Dict[str, list]]] = dict()
plain_errors = []
@@ -1132,7 +1133,14 @@ def json_error_response(
# we return that, as well as create a simple list of errors.
for e in field_errors.values():
plain_errors.extend(e)
- response_json["field_errors"] = field_errors
+ if None in field_errors.keys():
+ # Ugh - wtforms decided to use None as a key - which json
+ # a) can't sort
+ # b) converts to "null"
+ # Issue filed - maybe they will change it
+ field_errors[""] = field_errors[None]
+ del field_errors[None]
+ response_json["field_errors"] = field_errors # type: ignore
response_json["errors"] = plain_errors
return response_json
diff --git a/flask_security/views.py b/flask_security/views.py
index a83b18f8..f48923c0 100644
--- a/flask_security/views.py
+++ b/flask_security/views.py
@@ -767,6 +767,7 @@ def two_factor_setup():
else:
# Caller is changing their TFA profile. This requires a 'fresh' authentication
+ # N.B unauth_csrf has done the CSRF check already.
if not check_and_update_authn_fresh(
cv("FRESHNESS"),
cv("FRESHNESS_GRACE_PERIOD"),
diff --git a/flask_security/webauthn.py b/flask_security/webauthn.py
index 022bb253..4fdbdb56 100644
--- a/flask_security/webauthn.py
+++ b/flask_security/webauthn.py
@@ -560,8 +560,9 @@ def webauthn_register_response(token: str) -> "ResponseValue":
if _security._want_json(request):
return base_render_json(form)
- if len(form.errors) > 0:
- do_flash(form.errors["credential"][0], "error")
+ if form.errors:
+ for v in form.errors.values():
+ do_flash(v[0], "error")
return redirect(url_for_security("wan_register"))
@@ -742,8 +743,9 @@ def webauthn_signin_response(token: str) -> "ResponseValue":
# Since the response is auto submitted - we go back to
# signin form - for now use flash.
- if form.credential.errors:
- do_flash(form.credential.errors[0], "error")
+ if form.errors:
+ for v in form.errors.values():
+ do_flash(v[0], "error")
return redirect(url_for_security("wan_signin"))
diff --git a/pytest.ini b/pytest.ini
index 8c6a5b42..14552c89 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -15,6 +15,7 @@ markers =
unified_signin
webauthn
flask_async
+ csrf
filterwarnings =
error
diff --git a/tests/conftest.py b/tests/conftest.py
index 72610c67..5bbaedc7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,7 @@
from flask import Flask, Response, jsonify, render_template
from flask import request as flask_request
from flask_mailman import Mail
+from flask_wtf import CSRFProtect
from flask_security import (
MongoEngineUserDatastore,
@@ -175,6 +176,17 @@ def app(request: pytest.FixtureRequest) -> "SecurityFixture":
raise pytest.skip("Requires Babel")
Babel(app)
+ csrf = marker_getter("csrf")
+ if csrf is not None:
+ # without any keys/arguments - this is the default config
+ app.config["WTF_CSRF_ENABLED"] = True
+ if "ignore_unauth" in csrf.kwargs.keys():
+ app.config["WTF_CSRF_CHECK_DEFAULT"] = False
+ app.config["SECURITY_CSRF_IGNORE_UNAUTH_ENDPOINTS"] = True
+ if "csrfprotect" in csrf.kwargs.keys():
+ # This is needed when passing CSRF in header
+ CSRFProtect(app)
+
@app.route("/")
def index():
return render_template("index.html", content="Home Page")
diff --git a/tests/test_changeable.py b/tests/test_changeable.py
index 25042a49..bbc0f44b 100644
--- a/tests/test_changeable.py
+++ b/tests/test_changeable.py
@@ -21,7 +21,9 @@
from flask_security.utils import localize_callback
from tests.test_utils import (
authenticate,
+ check_location,
check_xlation,
+ get_form_input,
get_session,
hash_password,
init_app_with_options,
@@ -666,3 +668,42 @@ def test_pwd_no_normalize(app, client):
headers={"Content-Type": "application/json"},
)
assert response.status_code == 200
+
+
+@pytest.mark.csrf(ignore_unauth=True)
+@pytest.mark.settings(post_change_view="/post_change_view")
+def test_csrf(app, client):
+ # enable CSRF, make sure template shows CSRF errors.
+ authenticate(client)
+ data = {
+ "password": "password",
+ "new_password": "new strong password",
+ "new_password_confirm": "new strong password",
+ }
+ response = client.post("/change", data=data)
+ assert b"The CSRF token is missing" in response.data
+ # Note that we get a CSRF token EVEN for errors - this seems odd
+ # but can't find anything that says its a security issue
+ csrf_token = get_form_input(response, "csrf_token")
+
+ data["csrf_token"] = csrf_token
+ response = client.post("/change", data=data)
+ assert check_location(app, response.location, "/post_change_view")
+
+
+@pytest.mark.csrf(ignore_unauth=True, csrfprotect=True)
+def test_csrf_json(app, client):
+ authenticate(client)
+ data = {
+ "password": "password",
+ "new_password": "new strong password",
+ "new_password_confirm": "new strong password",
+ }
+ response = client.post("/change", json=data)
+ assert response.status_code == 400
+ assert response.json["response"]["errors"][0] == "The CSRF token is missing."
+
+ response = client.get("/change", content_type="application/json")
+ csrf_token = response.json["response"]["csrf_token"]
+ response = client.post("/change", json=data, headers={"X-CSRF-Token": csrf_token})
+ assert response.status_code == 200
diff --git a/tests/test_common.py b/tests/test_common.py
index 2b51027e..da6c9b78 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -258,9 +258,9 @@ def test_generic_response(app, client, get_message):
response = client.post(
"/login", json=dict(email="mattwho@lp.com", password="forgot")
)
- # make sure just 'null' key in errors.
- assert list(response.json["response"]["field_errors"].keys()) == ["null"]
- assert len(response.json["response"]["field_errors"]["null"]) == 1
+ # make sure no field error key.
+ assert list(response.json["response"]["field_errors"].keys()) == [""]
+ assert len(response.json["response"]["field_errors"][""]) == 1
assert response.json["response"]["errors"][0].encode("utf-8") == get_message(
"GENERIC_AUTHN_FAILED"
)
@@ -299,9 +299,9 @@ def test_generic_response_username(app, client, get_message):
assert get_message("GENERIC_AUTHN_FAILED") in response.data
response = client.post("/login", json=dict(username="dude2", password="forgot"))
- # make sure just 'null' key in errors.
- assert list(response.json["response"]["field_errors"].keys()) == ["null"]
- assert len(response.json["response"]["field_errors"]["null"]) == 1
+ # make sure no field error key.
+ assert list(response.json["response"]["field_errors"].keys()) == [""]
+ assert len(response.json["response"]["field_errors"][""]) == 1
assert response.json["response"]["errors"][0].encode("utf-8") == get_message(
"GENERIC_AUTHN_FAILED"
)
diff --git a/tests/test_csrf.py b/tests/test_csrf.py
index 3e4984f6..85e83239 100644
--- a/tests/test_csrf.py
+++ b/tests/test_csrf.py
@@ -4,7 +4,7 @@
CSRF tests
- :copyright: (c) 2019-2023 by J. Christopher Wagner (jwag).
+ :copyright: (c) 2019-2024 by J. Christopher Wagner (jwag).
:license: MIT, see LICENSE for more details.
"""
from contextlib import contextmanager
@@ -16,7 +16,7 @@
from flask_wtf import CSRFProtect
from flask_security import Security, hash_password
-from tests.test_utils import get_session, logout
+from tests.test_utils import get_form_input, get_session, logout
REAL_VALIDATE_CSRF = None
@@ -78,10 +78,7 @@ def json_login(
use_header=False,
remember=None,
):
- """Return tuple (auth_token, csrf_token)
- Note that since this is sent as JSON rather than form that csrfProtect
- won't find token value (since it looks in request.form).
- """
+ # Return tuple (auth_token, csrf_token)
csrf_token = _get_csrf_token(client)
data = dict(email=email, password=password, remember=remember)
@@ -109,16 +106,15 @@ def json_logout(client):
return response
+@pytest.mark.csrf()
def test_login_csrf(app, client):
- app.config["WTF_CSRF_ENABLED"] = True
-
# This shouldn't log in - but return login form with csrf token.
data = dict(email="matt@lp.com", password="password", remember="y")
response = client.post("/login", data=data)
assert response.status_code == 200
- assert b"csrf_token" in response.data
+ assert b"The CSRF token is missing." in response.data
- data["csrf_token"] = _get_csrf_token(client)
+ data["csrf_token"] = get_form_input(response, "csrf_token")
response = client.post("/login", data=data, follow_redirects=True)
assert response.status_code == 200
assert b"Welcome matt" in response.data
@@ -151,9 +147,8 @@ def test_login_csrf_double(app, client):
assert b"Welcome matt" in response.data
+@pytest.mark.csrf()
def test_login_csrf_json(app, client):
- app.config["WTF_CSRF_ENABLED"] = True
-
with mp_validate_csrf() as mp:
auth_token, csrf_token = json_login(client)
assert auth_token
@@ -166,14 +161,8 @@ def test_login_csrf_json(app, client):
assert "csrf_token" not in session
-def test_login_csrf_json_header(app, sqlalchemy_datastore):
- app.config["WTF_CSRF_ENABLED"] = True
- CSRFProtect(app)
- app.security = Security(app=app, datastore=sqlalchemy_datastore)
- create_user(app)
-
- client = app.test_client()
-
+@pytest.mark.csrf(csrfprotect=True)
+def test_login_csrf_json_header(app, client):
with mp_validate_csrf() as mp:
auth_token, csrf_token = json_login(client, use_header=True)
assert auth_token
@@ -215,21 +204,27 @@ def test_login_csrf_unauth_double(app, client, get_message):
)
+@pytest.mark.csrf()
@pytest.mark.recoverable()
def test_reset(app, client):
"""Test that form-based CSRF works for /reset"""
- app.config["WTF_CSRF_ENABLED"] = True
+ response = client.get("/reset", content_type="application/json")
+ csrf_token = response.json["response"]["csrf_token"]
with mp_validate_csrf() as mp:
data = dict(email="matt@lp.com")
# should fail - no CSRF token
response = client.post("/reset", content_type="application/json", json=data)
assert response.status_code == 400
+ assert response.json["response"]["errors"][0] == "The CSRF token is missing."
+ # test template also has error
+ response = client.post("/reset", data=data)
+ assert b"The CSRF token is missing" in response.data
- data["csrf_token"] = _get_csrf_token(client)
+ data["csrf_token"] = csrf_token
response = client.post("/reset", content_type="application/json", json=data)
assert response.status_code == 200
- assert mp.success == 1 and mp.failure == 1
+ assert mp.success == 1 and mp.failure == 2
@pytest.mark.recoverable()
@@ -246,6 +241,7 @@ def test_cp_reset(app, client):
# should fail - no CSRF token
response = client.post("/reset", content_type="application/json", json=data)
assert response.status_code == 400
+ assert response.json["response"]["errors"][0] == "The CSRF token is missing."
csrf_token = _get_csrf_token(client)
response = client.post(
@@ -261,15 +257,11 @@ def test_cp_reset(app, client):
@pytest.mark.changeable()
-def test_cp_with_token(app, sqlalchemy_datastore):
+@pytest.mark.csrf(csrfprotect=True)
+def test_cp_with_token(app, client):
# Make sure can use returned CSRF-Token in Header.
- app.config["WTF_CSRF_ENABLED"] = True
- CSRFProtect(app)
- app.security = Security(app=app, datastore=sqlalchemy_datastore)
- create_user(app)
-
- client = app.test_client()
-
+ # Since the csrf token isn't in the form - must enable app-wide CSRF
+ # using CSRFProtect() - as the above mark does.
auth_token, csrf_token = json_login(client, use_header=True)
# make sure returned csrf_token works in header.
@@ -442,17 +434,11 @@ def test_csrf_cookie(app, sqlalchemy_datastore):
assert not client.get_cookie("XSRF-Token")
+@pytest.mark.csrf(csrfprotect=True)
@pytest.mark.settings(CSRF_COOKIE={"key": "XSRF-Token"})
@pytest.mark.changeable()
-def test_cp_with_token_cookie(app, sqlalchemy_datastore):
- # Make sure can use returned CSRF-Token cookie in Header when changing password.
- app.config["WTF_CSRF_ENABLED"] = True
- CSRFProtect(app)
- app.security = Security(app=app, datastore=sqlalchemy_datastore)
- create_user(app)
-
- client = app.test_client()
-
+def test_cp_with_token_cookie(app, client):
+ # Make sure can use returned CSRF-Token cookie in Header when changing password
json_login(client, use_header=True)
# make sure returned csrf_token works in header.
@@ -517,19 +503,13 @@ def test_cp_with_token_cookie_expire(app, sqlalchemy_datastore):
assert not client.get_cookie("XSRF-Token")
+@pytest.mark.csrf(csrfprotect=True)
@pytest.mark.settings(
CSRF_COOKIE_NAME="XSRF-Token", CSRF_COOKIE_REFRESH_EACH_REQUEST=True
)
@pytest.mark.changeable()
-def test_cp_with_token_cookie_refresh(app, sqlalchemy_datastore):
+def test_cp_with_token_cookie_refresh(app, client):
# Test CSRF_COOKIE_REFRESH_EACH_REQUEST
- app.config["WTF_CSRF_ENABLED"] = True
- CSRFProtect(app)
- app.security = Security(app=app, datastore=sqlalchemy_datastore)
- create_user(app)
-
- client = app.test_client()
-
json_login(client, use_header=True)
# make sure returned csrf_token works in header.
@@ -564,25 +544,14 @@ def test_cp_with_token_cookie_refresh(app, sqlalchemy_datastore):
assert not client.get_cookie("XSRF-Token")
+@pytest.mark.csrf(csrfprotect=True)
@pytest.mark.settings(CSRF_COOKIE_NAME="XSRF-Token")
@pytest.mark.changeable()
-def test_remember_login_csrf_cookie(app, sqlalchemy_datastore):
+def test_remember_login_csrf_cookie(app, client):
# Test csrf cookie upon resuming a remember session
- app.config["WTF_CSRF_ENABLED"] = True
- CSRFProtect(app)
- app.security = Security(app=app, datastore=sqlalchemy_datastore)
- create_user(app)
-
- client = app.test_client()
-
# Login with remember_token generation
json_login(client, use_header=True, remember=True)
- # csrf_cookie = [c for c in client.cookie_jar if c.name == "XSRF-Token"][0]
- # session_cookie = [c for c in client.cookie_jar if c.name == "session"][0]
- # Delete session and csrf cookie - we should always get new ones
- # client.delete_cookie(csrf_cookie.domain, csrf_cookie.name)
- # client.delete_cookie(session_cookie.domain, session_cookie.name)
client.delete_cookie("XSRF-Token")
client.delete_cookie("session")
diff --git a/tests/test_misc.py b/tests/test_misc.py
index ef67a170..0525615a 100644
--- a/tests/test_misc.py
+++ b/tests/test_misc.py
@@ -886,10 +886,12 @@ def test_json_form_errors(app, client):
"""Test wtforms form level errors are correctly sent via json"""
with app.test_request_context():
form = ChangePasswordForm()
+ form.validate()
form.form_errors.append("I am an error")
response = base_render_json(form)
- assert len(response.json["response"]["errors"]) == 1
- assert response.json["response"]["errors"][0] == "I am an error"
+ error_list = response.json["response"]["errors"]
+ assert len(error_list) == 3
+ assert "I am an error" in error_list
def test_method_view(app, client):
@@ -1313,7 +1315,7 @@ def test_open_redirect(app, client, get_message):
(r"/\github.com", "/%5Cgithub.com"),
(r"\/github.com", "%5C/github.com"),
("//github.com", ""),
- ("\t//github.com", ""),
+ ("\t//github.com", "%09//github.com"),
]
for i, o in test_urls:
data = dict(email="matt@lp.com", password="password", next=i)
diff --git a/tests/test_oauthglue.py b/tests/test_oauthglue.py
index 8b40b920..c151e4bc 100644
--- a/tests/test_oauthglue.py
+++ b/tests/test_oauthglue.py
@@ -87,7 +87,7 @@ def test_github(app, sqlalchemy_datastore, get_message):
# make sure required CSRF
response = client.post(github_url, follow_redirects=False)
- assert "The CSRF token is missing"
+ assert b"The CSRF token is missing" in response.data
response = client.post(
github_url, data=dict(csrf_token=csrf_token), follow_redirects=False
diff --git a/tests/test_recoverable.py b/tests/test_recoverable.py
index b0288a83..6cf0d174 100644
--- a/tests/test_recoverable.py
+++ b/tests/test_recoverable.py
@@ -3,6 +3,9 @@
~~~~~~~~~~~~~~~~
Recoverable functionality tests
+
+ :copyright: (c) 2019-2024 by J. Christopher Wagner (jwag).
+ :license: MIT, see LICENSE for more details.
"""
import re
@@ -17,6 +20,8 @@
authenticate,
capture_flashes,
capture_reset_password_requests,
+ check_location,
+ get_form_input,
logout,
populate_data,
)
@@ -761,3 +766,29 @@ async def on_instructions_sent(myapp, **kwargs):
)
assert not response.json["response"]
assert len(recorded_resets) == 1
+
+
+@pytest.mark.csrf()
+@pytest.mark.settings(post_reset_view="/post_reset_view")
+def test_csrf(app, client, get_message):
+ response = client.get("/reset")
+ csrf_token = get_form_input(response, "csrf_token")
+ with capture_reset_password_requests() as requests:
+ client.post(
+ "/reset",
+ data=dict(email="joe@lp.com", csrf_token=csrf_token),
+ follow_redirects=True,
+ )
+ token = requests[0]["token"]
+
+ # use the token - no CSRF so shouldn't work
+ data = {"password": "mypassword", "password_confirm": "mypassword"}
+ response = client.post(
+ "/reset/" + token,
+ data=data,
+ )
+ assert b"The CSRF token is missing" in response.data
+
+ data["csrf_token"] = csrf_token
+ response = client.post(f"/reset/{token}", data=data)
+ assert check_location(app, response.location, "/post_reset_view")
diff --git a/tests/test_registerable.py b/tests/test_registerable.py
index 42dd48c1..15ec3c15 100644
--- a/tests/test_registerable.py
+++ b/tests/test_registerable.py
@@ -516,7 +516,7 @@ def test_username(app, client, get_message):
assert response.status_code == 400
assert (
get_message("USER_DOES_NOT_EXIST")
- == response.json["response"]["field_errors"]["null"][0].encode()
+ == response.json["response"]["field_errors"][""][0].encode()
)
# login using us-signin
diff --git a/tests/test_two_factor.py b/tests/test_two_factor.py
index 4d06ab5a..7525b2a0 100644
--- a/tests/test_two_factor.py
+++ b/tests/test_two_factor.py
@@ -30,6 +30,7 @@
check_location,
check_xlation,
get_form_action,
+ get_form_input,
get_session,
is_authenticated,
logout,
@@ -1469,3 +1470,38 @@ def test_xlation(app, client, get_message_local):
with app.test_request_context():
existing = "Méthode à deux facteurs actuellement configurée : authentificateur"
assert markupsafe.escape(existing).encode() in response.data
+
+
+@pytest.mark.csrf(ignore_unauth=True)
+@pytest.mark.settings(two_factor_post_setup_view="/post_setup_view")
+def test_setup_csrf(app, client):
+ # Verify /tf-setup properly handles CSRF and template relays CSRF errors
+ tf_authenticate(app, client)
+ response = client.get("tf-setup")
+ assert b"Disable" in response.data
+ csrf_token = get_form_input(response, "csrf_token")
+
+ response = client.post("tf-setup", data=dict(setup="disable"))
+ assert b"The CSRF token is missing" in response.data
+
+ response = client.post(
+ "tf-setup", data=dict(setup="disable", csrf_token=csrf_token)
+ )
+ assert check_location(app, response.location, "/post_setup_view")
+
+
+@pytest.mark.csrf(ignore_unauth=True, csrfprotect=True)
+def test_setup_csrf_header(app, client):
+ # Test that can setup using csrf token in header
+ tf_authenticate(app, client)
+ response = client.get("tf-setup", json=dict())
+ csrf_token = response.json["response"]["csrf_token"]
+
+ response = client.post("tf-setup", json=dict(setup="disable"))
+ assert response.status_code == 400
+ assert response.json["response"]["errors"][0] == "The CSRF token is missing."
+
+ response = client.post(
+ "tf-setup", json=dict(setup="disable"), headers={"X-CSRF-Token": csrf_token}
+ )
+ assert response.status_code == 200
diff --git a/tests/test_unified_signin.py b/tests/test_unified_signin.py
index b0ba6e00..b9435179 100644
--- a/tests/test_unified_signin.py
+++ b/tests/test_unified_signin.py
@@ -2030,9 +2030,9 @@ def test_generic_response(app, client, get_message):
data = dict(identity="matt@lp.com", code="12345")
response = client.post("/us-signin", json=data)
assert response.status_code == 400
- assert list(response.json["response"]["field_errors"].keys()) == ["null"]
- assert len(response.json["response"]["field_errors"]["null"]) == 1
- assert response.json["response"]["field_errors"]["null"][0].encode(
+ assert list(response.json["response"]["field_errors"].keys()) == [""]
+ assert len(response.json["response"]["field_errors"][""]) == 1
+ assert response.json["response"]["field_errors"][""][0].encode(
"utf-8"
) == get_message("GENERIC_AUTHN_FAILED")
assert response.json["response"]["errors"][0].encode("utf-8") == get_message(
@@ -2047,9 +2047,9 @@ def test_generic_response(app, client, get_message):
data = dict(identity="matt2@lp.com", code="12345")
response = client.post("/us-signin", json=data)
assert response.status_code == 400
- assert list(response.json["response"]["field_errors"].keys()) == ["null"]
- assert len(response.json["response"]["field_errors"]["null"]) == 1
- assert response.json["response"]["field_errors"]["null"][0].encode(
+ assert list(response.json["response"]["field_errors"].keys()) == [""]
+ assert len(response.json["response"]["field_errors"][""]) == 1
+ assert response.json["response"]["field_errors"][""][0].encode(
"utf-8"
) == get_message("GENERIC_AUTHN_FAILED")
assert response.json["response"]["errors"][0].encode("utf-8") == get_message(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9a6027f9..e3f85742 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -35,9 +35,17 @@
def authenticate(
- client, email="matt@lp.com", password="password", endpoint=None, **kwargs
+ client,
+ email="matt@lp.com",
+ password="password",
+ endpoint=None,
+ csrf=False,
+ **kwargs,
):
data = dict(email=email, password=password, remember="y")
+ if csrf:
+ response = client.get(endpoint or "/login")
+ data["csrf_token"] = get_form_input(response, "csrf_token")
return client.post(endpoint or "/login", data=data, **kwargs)
diff --git a/tests/test_webauthn.py b/tests/test_webauthn.py
index 00fee8b6..d2460f4a 100644
--- a/tests/test_webauthn.py
+++ b/tests/test_webauthn.py
@@ -23,6 +23,7 @@
FakeSerializer,
authenticate,
capture_flashes,
+ check_location,
get_existing_session,
get_form_action,
get_form_input,
@@ -185,8 +186,12 @@ def origin(self):
return "http://localhost:5001"
-def _register_start(client, name="testr1", usage="secondary", endpoint="wan-register"):
- response = client.post(endpoint, data=dict(name=name, usage=usage))
+def _register_start(
+ client, name="testr1", usage="secondary", endpoint="wan-register", csrf_token=None
+):
+ response = client.post(
+ endpoint, data=dict(name=name, usage=usage, csrf_token=csrf_token)
+ )
matcher = re.match(
r".*handleRegister\(\'(.*)\'\).*",
response.data.decode("utf-8"),
@@ -230,8 +235,11 @@ def _signin_start(
client,
identity=None,
endpoint="wan-signin",
+ csrf_token=None,
):
- response = client.post(endpoint, data=dict(identity=identity))
+ response = client.post(
+ endpoint, data=dict(identity=identity, csrf_token=csrf_token)
+ )
matcher = re.match(
r".*handleSignin\(\'(.*)\'\).*",
response.data.decode("utf-8"),
@@ -1713,3 +1721,50 @@ async def wan_delete(sender, user, name, **extra_args):
response = client.post(
"/wan-delete", data=dict(name="testr1"), follow_redirects=True
)
+
+
+@pytest.mark.csrf()
+@pytest.mark.settings(
+ webauthn_util_cls=HackWebauthnUtil,
+ wan_post_register_view="/done-registration",
+ post_login_view="/post-login",
+)
+def test_csrf(app, client, get_message):
+ response = client.get("/login")
+ csrf_token = get_form_input(response, "csrf_token")
+ authenticate(client, csrf=True)
+
+ register_options, response_url = _register_start(
+ client, usage="first", csrf_token=csrf_token
+ )
+ data = dict(credential=json.dumps(REG_DATA1))
+ response = client.post(response_url, data=data, follow_redirects=True)
+ assert (
+ b"The CSRF token is missing" in response.data
+ ) # this should have been flashed
+
+ data["csrf_token"] = csrf_token
+ response = client.post(response_url, data=data)
+ assert check_location(app, response.location, "/done-registration")
+ logout(client)
+
+ # use old csrf_token - should fail and we should get the error in the template
+ response = client.post(
+ "wan-signin", data=dict(identity="matt@lp.com", csrf_token=csrf_token)
+ )
+ assert b"The CSRF tokens do not match." in response.data
+
+ response = client.get("/wan-signin")
+ csrf_token = get_form_input(response, "csrf_token")
+ signin_options, response_url = _signin_start(
+ client, "matt@lp.com", csrf_token=csrf_token
+ )
+ data = dict(credential=json.dumps(SIGNIN_DATA1))
+ response = client.post(response_url, data=data, follow_redirects=True)
+ assert (
+ b"The CSRF token is missing" in response.data
+ ) # this should have been flashed
+
+ data["csrf_token"] = csrf_token
+ response = client.post(response_url, data=data)
+ assert check_location(app, response.location, "/post-login")