Skip to content

Commit

Permalink
Add AuthenticationWebSocketInterceptor
Browse files Browse the repository at this point in the history
See gh-268

Co-authored-by: Josh Cummings <josh.cummings@gmail.com>
  • Loading branch information
rstoyanchev and jzheaux committed May 17, 2024
1 parent f354950 commit 1171aee
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 0 deletions.
1 change: 1 addition & 0 deletions spring-graphql/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {
compileOnly 'jakarta.validation:jakarta.validation-api'

compileOnly 'org.springframework.security:spring-security-core'
compileOnly 'org.springframework.security:spring-security-oauth2-resource-server'

compileOnly 'com.querydsl:querydsl-core'
compileOnly 'org.springframework.data:spring-data-commons'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.graphql.server.support;

import java.util.Map;

import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;

import org.springframework.graphql.server.WebGraphQlRequest;
import org.springframework.graphql.server.WebGraphQlResponse;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;

/**
* Base class for interceptors that extract an {@link Authentication} from
* the payload of a {@code "connection_init"} GraphQL over WebSocket message.
* The authentication is saved in WebSocket attributes from where it is later
* accessed and propagated to subsequent {@code "subscribe"} messages.
*
* @author Joshua Cummings
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public abstract class AbstractAuthenticationWebSocketInterceptor implements WebSocketGraphQlInterceptor {

private static final String AUTHENTICATION_ATTRIBUTE =
AbstractAuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION";


private final AuthenticationExtractor authenticationExtractor;


/**
* Constructor with the strategy to use to extract the authentication value
* from the {@code "connection_init"} message.
* @param authExtractor the extractor to use
*/
public AbstractAuthenticationWebSocketInterceptor(AuthenticationExtractor authExtractor) {
this.authenticationExtractor = authExtractor;
}

@Override
public Mono<Object> handleConnectionInitialization(WebSocketSessionInfo info, Map<String, Object> payload) {
return this.authenticationExtractor.getAuthentication(payload)
.flatMap(this::getSecurityContext)
.doOnNext((securityContext) -> info.getAttributes().put(AUTHENTICATION_ATTRIBUTE, securityContext))
.then(Mono.empty());
}

/**
* Subclasses implement this method to return an authenticated
* {@link SecurityContext} or an error.
* @param authentication the authentication value extracted from the payload
*/
protected abstract Mono<SecurityContext> getSecurityContext(Authentication authentication);

@Override
public Mono<WebGraphQlResponse> intercept(WebGraphQlRequest request, Chain chain) {
if (!(request instanceof WebSocketGraphQlRequest webSocketRequest)) {
return chain.next(request);
}
Map<String, Object> attributes = webSocketRequest.getSessionInfo().getAttributes();
SecurityContext securityContext = (SecurityContext) attributes.get(AUTHENTICATION_ATTRIBUTE);
ContextView contextView = getContextToWrite(securityContext);
return chain.next(request).contextWrite(contextView);
}

/**
* Subclasses implement this to decide how to insert the {@link SecurityContext}
* into the Reactor context of the {@link WebSocketGraphQlInterceptor} chain.
* @param securityContext the {@code SecurityContext} to write to the context
*/
protected abstract ContextView getContextToWrite(SecurityContext securityContext);

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.graphql.server.support;

import java.util.Map;

import reactor.core.publisher.Mono;

import org.springframework.security.core.Authentication;

/**
* Strategy to extract an {@link Authentication} from the payload of a
* {@code "connection_init"} GraphQL over WebSocket message.
*
* @author Joshua Cummings
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public interface AuthenticationExtractor {

/**
* Return the authentication contained in the given payload, or an empty {@code Mono}.
* @param payload the payload to extract the authentication value from
*/
Mono<Authentication> getAuthentication(Map<String, Object> payload);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.graphql.server.support;

import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import reactor.core.publisher.Mono;

import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenError;
import org.springframework.security.oauth2.server.resource.BearerTokenErrors;
import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthenticationToken;
import org.springframework.util.StringUtils;

/**
* {@link AuthenticationExtractor} that extracts a
* <a href="https://datatracker.ietf.org/doc/html/rfc6750#section-1.2">bearer token</a>.
*
* @author Joshua Cummings
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public final class BearerTokenAuthenticationExtractor implements AuthenticationExtractor {

private static final Pattern authorizationPattern =
Pattern.compile("^Bearer (?<token>[a-zA-Z0-9-._~+/]+=*)$", Pattern.CASE_INSENSITIVE);


private final String authorizationKey;


/**
* Constructor that defaults the payload key to use to "Authorization".
*/
public BearerTokenAuthenticationExtractor() {
this("Authorization");
}

/**
* Constructor with the key for the authorization value.
* @param authorizationKey the key under which to look up the authorization
* value in the {@code "connection_init"} payload.
*/
public BearerTokenAuthenticationExtractor(String authorizationKey) {
this.authorizationKey = authorizationKey;
}


@Override
public Mono<Authentication> getAuthentication(Map<String, Object> payload) {
String authorizationValue = (String) payload.get(this.authorizationKey);
if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) {
return Mono.empty();
}

Matcher matcher = authorizationPattern.matcher(authorizationValue);
if (matcher.matches()) {
String token = matcher.group("token");
return Mono.just(new BearerTokenAuthenticationToken(token));
}

BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed");
return Mono.error(new OAuth2AuthenticationException(error));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.graphql.server.webflux;

import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;

import org.springframework.graphql.server.support.AbstractAuthenticationWebSocketInterceptor;
import org.springframework.graphql.server.support.AuthenticationExtractor;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;

/**
* Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with
* the WebFlux GraphQL transport.
*
* @author Joshua Cummings
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {

private final ReactiveAuthenticationManager authenticationManager;


public AuthenticationWebSocketInterceptor(
AuthenticationExtractor extractor, ReactiveAuthenticationManager manager) {

super(extractor);
this.authenticationManager = manager;
}

@Override
protected Mono<SecurityContext> getSecurityContext(Authentication authentication) {
return this.authenticationManager.authenticate(authentication).map(SecurityContextImpl::new);
}

@Override
protected ContextView getContextToWrite(SecurityContext securityContext) {
return ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext));
}

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.graphql.server.webmvc;

import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import reactor.util.context.ContextView;

import org.springframework.graphql.server.support.AbstractAuthenticationWebSocketInterceptor;
import org.springframework.graphql.server.support.AuthenticationExtractor;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;

/**
* Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with
* the WebMVC GraphQL transport.
*
* @author Joshua Cummings
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {

private final AuthenticationManager authenticationManager;


public AuthenticationWebSocketInterceptor(
AuthenticationManager authManager, AuthenticationExtractor authExtractor) {

super(authExtractor);
this.authenticationManager = authManager;
}

@Override
protected Mono<SecurityContext> getSecurityContext(Authentication authentication) {
Authentication authenticate = this.authenticationManager.authenticate(authentication);
return Mono.just(new SecurityContextImpl(authenticate));
}

@Override
protected ContextView getContextToWrite(SecurityContext securityContext) {
String key = SecurityContext.class.getName(); // match SecurityContextThreadLocalAccessor key
return Context.of(key, securityContext);
}

}

0 comments on commit 1171aee

Please sign in to comment.