Skip to content

Commit

Permalink
Support redirecting to page according to query after authenticated (#…
Browse files Browse the repository at this point in the history
…6736)

#### What type of PR is this?

/kind improvement
/area core
/milestone 2.20.0

#### What this PR does / why we need it:

This PR supports query `redirect_uri` to control where to redirect after authenticated.

#### Which issue(s) this PR fixes:

Fixes #6720

#### Special notes for your reviewer:

Every step below needs you logging out.

1. Try to request <http://localhost:8090/console/login?redirect_uri=/xxx
2. Try to request <http://localhost:8090/login?redirect_uri=/xxx
3. Try to request <http://localhost:8090/console/posts

#### Does this PR introduce a user-facing change?

```release-note
None
```
  • Loading branch information
JohnNiang authored Sep 30, 2024
1 parent 8a9b954 commit db65dd3
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.NegatedServerWebExchangeMatcher;
Expand All @@ -36,6 +37,7 @@
import run.halo.app.infra.AnonymousUserConst;
import run.halo.app.infra.properties.HaloProperties;
import run.halo.app.security.DefaultUserDetailService;
import run.halo.app.security.HaloServerRequestCache;
import run.halo.app.security.authentication.CryptoService;
import run.halo.app.security.authentication.SecurityConfigurer;
import run.halo.app.security.authentication.impl.RsaKeyService;
Expand Down Expand Up @@ -64,7 +66,8 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http,
ServerSecurityContextRepository securityContextRepository,
ReactiveExtensionClient client,
CryptoService cryptoService,
HaloProperties haloProperties) {
HaloProperties haloProperties,
ServerRequestCache serverRequestCache) {

var pathMatcher = pathMatchers("/**");
var staticResourcesMatcher = pathMatchers(HttpMethod.GET,
Expand Down Expand Up @@ -134,14 +137,20 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http,
haloProperties.getSecurity().getReferrerOptions().getPolicy())
)
.hsts(hstsSpec -> hstsSpec.includeSubdomains(false))
);
)
.requestCache(spec -> spec.requestCache(serverRequestCache));

// Integrate with other configurers separately
securityConfigurers.orderedStream()
.forEach(securityConfigurer -> securityConfigurer.configure(http));
return http.build();
}

@Bean
ServerRequestCache serverRequestCache() {
return new HaloServerRequestCache();
}

@Bean
ServerSecurityContextRepository securityContextRepository() {
return new WebSessionServerSecurityContextRepository();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import run.halo.app.content.PostContentService;
import run.halo.app.core.extension.service.AttachmentService;
import run.halo.app.extension.DefaultSchemeManager;
Expand Down Expand Up @@ -79,6 +80,12 @@ public static ApplicationContext create(ApplicationContext rootContext) {
.ifUnique(rateLimiterRegistry ->
beanFactory.registerSingleton("rateLimiterRegistry", rateLimiterRegistry)
);

// Authentication plugins may need this RequestCache to handle successful login redirect
rootContext.getBeanProvider(ServerRequestCache.class)
.ifUnique(serverRequestCache ->
beanFactory.registerSingleton("serverRequestCache", serverRequestCache)
);
// TODO add more shared instance here

sharedContext.refresh();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
import org.springframework.web.server.ServerWebExchange;
Expand All @@ -30,17 +31,19 @@ public class DefaultServerAuthenticationEntryPoint implements ServerAuthenticati

private final RedirectServerAuthenticationEntryPoint redirectEntryPoint;

public DefaultServerAuthenticationEntryPoint() {
this.redirectEntryPoint =
public DefaultServerAuthenticationEntryPoint(ServerRequestCache serverRequestCache) {
var entryPoint =
new RedirectServerAuthenticationEntryPoint("/login?authentication_required");
entryPoint.setRequestCache(serverRequestCache);
this.redirectEntryPoint = entryPoint;
}

@Override
public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException ex) {
return xhrMatcher.matches(exchange)
.filter(MatchResult::isMatch)
.switchIfEmpty(
Mono.defer(() -> this.redirectEntryPoint.commence(exchange, ex)).then(Mono.empty())
Mono.defer(() -> this.redirectEntryPoint.commence(exchange, ex).then(Mono.empty()))
)
.flatMap(match -> Mono.defer(
() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.springframework.security.web.server.authentication.AuthenticationConverterServerWebExchangeMatcher;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.security.web.server.authorization.ServerWebExchangeDelegatingServerAccessDeniedHandler;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.stereotype.Component;
Expand All @@ -24,10 +25,14 @@ public class ExceptionSecurityConfigurer implements SecurityConfigurer {

private final ServerResponse.Context context;

private final ServerRequestCache serverRequestCache;

public ExceptionSecurityConfigurer(MessageSource messageSource,
ServerResponse.Context context) {
ServerResponse.Context context,
ServerRequestCache serverRequestCache) {
this.messageSource = messageSource;
this.context = context;
this.serverRequestCache = serverRequestCache;
}

@Override
Expand Down Expand Up @@ -59,7 +64,7 @@ public void configure(ServerHttpSecurity http) {
));
entryPoints.add(new DelegatingServerAuthenticationEntryPoint.DelegateEntry(
exchange -> ServerWebExchangeMatcher.MatchResult.match(),
new DefaultServerAuthenticationEntryPoint()
new DefaultServerAuthenticationEntryPoint(serverRequestCache)
));

exception.authenticationEntryPoint(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package run.halo.app.security;

import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers.pathMatchers;

import java.net.URI;
import java.util.Collections;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.RequestPath;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.NegatedServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;

/**
* Halo server request cache implementation for saving redirect URI from query.
*
* @author johnniang
*/
public class HaloServerRequestCache extends WebSessionServerRequestCache {

/**
* Currently, we have no idea to customize the sessionAttributeName in
* WebSessionServerRequestCache, so we have to copy the attr into here.
*/
private static final String DEFAULT_SAVED_REQUEST_ATTR = "SPRING_SECURITY_SAVED_REQUEST";

private static final String REDIRECT_URI_QUERY = "redirect_uri";

private final String sessionAttrName = DEFAULT_SAVED_REQUEST_ATTR;

public HaloServerRequestCache() {
super();
setSaveRequestMatcher(createDefaultRequestMatcher());
}

@Override
public Mono<Void> saveRequest(ServerWebExchange exchange) {
var redirectUriQuery = exchange.getRequest().getQueryParams().getFirst(REDIRECT_URI_QUERY);
if (StringUtils.isNotBlank(redirectUriQuery)) {
var redirectUri = URI.create(redirectUriQuery);
return saveRedirectUri(exchange, redirectUri);
}
return super.saveRequest(exchange);
}

@Override
public Mono<URI> getRedirectUri(ServerWebExchange exchange) {
return super.getRedirectUri(exchange);
}

@Override
public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) {
return super.removeMatchingRequest(exchange);
}

private Mono<Void> saveRedirectUri(ServerWebExchange exchange, URI redirectUri) {
var requestPath = exchange.getRequest().getPath();
var redirectPath = RequestPath.parse(redirectUri, requestPath.contextPath().value());
var query = redirectUri.getRawQuery();
var finalRedirect =
redirectPath.pathWithinApplication() + (query == null ? "" : "?" + query);
return exchange.getSession()
.map(WebSession::getAttributes)
.doOnNext(attributes -> attributes.put(this.sessionAttrName, finalRedirect))
.then();
}

private static ServerWebExchangeMatcher createDefaultRequestMatcher() {
var get = pathMatchers(HttpMethod.GET, "/**");
var notFavicon = new NegatedServerWebExchangeMatcher(
pathMatchers(
"/favicon.*", "/login/**", "/signup/**", "/password-reset/**", "/challenges/**"
));
var html = new MediaTypeServerWebExchangeMatcher(MediaType.TEXT_HTML);
html.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
return new AndServerWebExchangeMatcher(get, notFavicon, html);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import reactor.core.publisher.Mono;
import run.halo.app.security.LoginHandlerEnhancer;

Expand All @@ -13,11 +14,14 @@ public class TotpAuthenticationSuccessHandler implements ServerAuthenticationSuc

private final LoginHandlerEnhancer loginEnhancer;

private final ServerAuthenticationSuccessHandler successHandler =
new RedirectServerAuthenticationSuccessHandler("/uc");
private final ServerAuthenticationSuccessHandler successHandler;

public TotpAuthenticationSuccessHandler(LoginHandlerEnhancer loginEnhancer) {
public TotpAuthenticationSuccessHandler(LoginHandlerEnhancer loginEnhancer,
ServerRequestCache serverRequestCache) {
this.loginEnhancer = loginEnhancer;
var successHandler = new RedirectServerAuthenticationSuccessHandler("/uc");
successHandler.setRequestCache(serverRequestCache);
this.successHandler = successHandler;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationFailureHandler;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.stereotype.Component;
import run.halo.app.security.LoginHandlerEnhancer;
import run.halo.app.security.authentication.SecurityConfigurer;
Expand All @@ -24,13 +25,17 @@ public class TwoFactorAuthSecurityConfigurer implements SecurityConfigurer {

private final LoginHandlerEnhancer loginHandlerEnhancer;

private final ServerRequestCache serverRequestCache;

public TwoFactorAuthSecurityConfigurer(
ServerSecurityContextRepository securityContextRepository,
TotpAuthService totpAuthService, LoginHandlerEnhancer loginHandlerEnhancer
TotpAuthService totpAuthService, LoginHandlerEnhancer loginHandlerEnhancer,
ServerRequestCache serverRequestCache
) {
this.securityContextRepository = securityContextRepository;
this.totpAuthService = totpAuthService;
this.loginHandlerEnhancer = loginHandlerEnhancer;
this.serverRequestCache = serverRequestCache;
}

@Override
Expand All @@ -43,7 +48,7 @@ public void configure(ServerHttpSecurity http) {
filter.setSecurityContextRepository(securityContextRepository);
filter.setServerAuthenticationConverter(new TotpCodeAuthenticationConverter());
filter.setAuthenticationSuccessHandler(
new TotpAuthenticationSuccessHandler(loginHandlerEnhancer)
new TotpAuthenticationSuccessHandler(loginHandlerEnhancer, serverRequestCache)
);
filter.setAuthenticationFailureHandler(
new RedirectServerAuthenticationFailureHandler("/challenges/two-factor/totp?error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import java.net.URI;
import org.springframework.context.MessageSource;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.server.DefaultServerRedirectStrategy;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
import org.springframework.security.web.server.ServerRedirectStrategy;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.reactive.function.server.ServerResponse;
import org.springframework.web.server.ServerWebExchange;
Expand All @@ -18,10 +19,13 @@ public class TwoFactorAuthenticationEntryPoint implements ServerAuthenticationEn
.flatMap(a -> ServerWebExchangeMatcher.MatchResult.match())
.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());

private static final String REDIRECT_LOCATION = "/challenges/two-factor/totp";
private static final URI REDIRECT_LOCATION = URI.create("/challenges/two-factor/totp");

private final RedirectServerAuthenticationEntryPoint redirectEntryPoint =
new RedirectServerAuthenticationEntryPoint(REDIRECT_LOCATION);
/**
* Because we don't want to cache the request before redirecting to the 2FA page,
* ServerRedirectStrategy is used to redirect the request.
*/
private final ServerRedirectStrategy redirectStrategy = new DefaultServerRedirectStrategy();

private final MessageSource messageSource;

Expand All @@ -45,10 +49,12 @@ public TwoFactorAuthenticationEntryPoint(MessageSource messageSource,
public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException ex) {
return XHR_MATCHER.matches(exchange)
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.switchIfEmpty(redirectEntryPoint.commence(exchange, ex).then(Mono.empty()))
.switchIfEmpty(
redirectStrategy.sendRedirect(exchange, REDIRECT_LOCATION).then(Mono.empty())
)
.flatMap(isXhr -> {
var errorResponse = Exceptions.createErrorResponse(
new TwoFactorAuthRequiredException(URI.create(REDIRECT_LOCATION)),
new TwoFactorAuthRequiredException(REDIRECT_LOCATION),
null, exchange, messageSource);
return ServerResponse.status(errorResponse.getStatusCode())
.bodyValue(errorResponse.getBody())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.annotation.Bean;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
Expand All @@ -18,6 +19,7 @@
import run.halo.app.infra.actuator.GlobalInfoService;
import run.halo.app.plugin.PluginConst;
import run.halo.app.security.AuthProviderService;
import run.halo.app.security.HaloServerRequestCache;
import run.halo.app.security.authentication.CryptoService;

/**
Expand All @@ -35,6 +37,8 @@ class PreAuthLoginEndpoint {

private final AuthProviderService authProviderService;

private final ServerRequestCache serverRequestCache = new HaloServerRequestCache();

PreAuthLoginEndpoint(CryptoService cryptoService, GlobalInfoService globalInfoService,
AuthProviderService authProviderService) {
this.cryptoService = cryptoService;
Expand All @@ -46,6 +50,7 @@ class PreAuthLoginEndpoint {
RouterFunction<ServerResponse> preAuthLoginEndpoints() {
return RouterFunctions.nest(path("/login"), RouterFunctions.route()
.GET("", request -> {
// TODO get redirect URI and cache it
var exchange = request.exchange();
var contextPath = exchange.getRequest().getPath().contextPath().value();
var publicKey = cryptoService.readPublicKey()
Expand Down Expand Up @@ -78,15 +83,17 @@ RouterFunction<ServerResponse> preAuthLoginEndpoints() {
.filter(ap -> !Objects.equals(loginMethod, ap.getMetadata().getName()))
.cache();

return ServerResponse.ok().render("login", Map.of(
"action", contextPath + "/login",
"publicKey", publicKey,
"globalInfo", globalInfo,
"authProvider", authProvider,
"fragmentTemplateName", fragmentTemplateName,
"socialAuthProviders", socialAuthProviders,
"formAuthProviders", formAuthProviders
// TODO Add more models here
return serverRequestCache.saveRequest(exchange).then(Mono.defer(() ->
ServerResponse.ok().render("login", Map.of(
"action", contextPath + "/login",
"publicKey", publicKey,
"globalInfo", globalInfo,
"authProvider", authProvider,
"fragmentTemplateName", fragmentTemplateName,
"socialAuthProviders", socialAuthProviders,
"formAuthProviders", formAuthProviders
// TODO Add more models here
))
));
})
.build());
Expand Down
Loading

0 comments on commit db65dd3

Please sign in to comment.