Skip to content

Commit

Permalink
Add CachingRelyingPartyRegistrationRepository
Browse files Browse the repository at this point in the history
Closes gh-15341
  • Loading branch information
jzheaux committed Jul 2, 2024
1 parent 1e29003 commit 7b39800
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 0 deletions.
51 changes: 51 additions & 0 deletions docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,57 @@ class MyCustomSecurityConfiguration {
A relying party can be multi-tenant by registering more than one relying party in the `RelyingPartyRegistrationRepository`.
====

[[servlet-saml2login-relyingpartyregistrationrepository-caching]]
If you want your metadata to be refreshable on a periodic basis, you can wrap your repository in `CachingRelyingPartyRegistrationRepository` like so:

.Caching Relying Party Registration Repository
[tabs]
======
Java::
+
[source,java,role="primary"]
----
@Configuration
@EnableWebSecurity
public class MyCustomSecurityConfiguration {
@Bean
public RelyingPartyRegistrationRepository registrations(CacheManager cacheManager) {
Supplier<IterableRelyingPartyRegistrationRepository> delegate = () ->
new InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations
.fromMetadataLocation("https://idp.example.org/ap/metadata")
.registrationId("ap").build());
CachingRelyingPartyRegistrationRepository registrations =
new CachingRelyingPartyRegistrationRepository(delegate);
registrations.setCache(cacheManager.getCache("my-cache-name"));
return registrations;
}
}
----
Kotlin::
+
[source,kotlin,role="secondary"]
----
@Configuration
@EnableWebSecurity
class MyCustomSecurityConfiguration {
@Bean
fun registrations(cacheManager: CacheManager): RelyingPartyRegistrationRepository {
val delegate = Supplier<IterableRelyingPartyRegistrationRepository> {
InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations
.fromMetadataLocation("https://idp.example.org/ap/metadata")
.registrationId("ap").build())
}
val registrations = CachingRelyingPartyRegistrationRepository(delegate)
registrations.setCache(cacheManager.getCache("my-cache-name"))
return registrations
}
}
----
======

In this way, the set of `RelyingPartyRegistration`s will refresh based on {spring-framework-reference-url}integration/cache/store-configuration.html[the cache's eviction schedule].

[[servlet-saml2login-relyingpartyregistration]]
== RelyingPartyRegistration
A {security-api-url}org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.html[`RelyingPartyRegistration`]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.registration;

import java.util.Iterator;
import java.util.Spliterator;
import java.util.concurrent.Callable;
import java.util.function.Consumer;

import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.util.Assert;

/**
* An {@link IterableRelyingPartyRegistrationRepository} that lazily queries and caches
* metadata from a backing {@link IterableRelyingPartyRegistrationRepository}. Delegates
* caching policies to Spring Cache.
*
* @author Josh Cummings
* @since 6.4
*/
public final class CachingRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository {

private final Callable<IterableRelyingPartyRegistrationRepository> registrationLoader;

private Cache cache = new ConcurrentMapCache("registrations");

public CachingRelyingPartyRegistrationRepository(Callable<IterableRelyingPartyRegistrationRepository> loader) {
this.registrationLoader = loader;
}

/**
* {@inheritDoc}
*/
@Override
public Iterator<RelyingPartyRegistration> iterator() {
return registrations().iterator();
}

/**
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
return registrations().findByRegistrationId(registrationId);
}

@Override
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
return registrations().findUniqueByAssertingPartyEntityId(entityId);
}

@Override
public void forEach(Consumer<? super RelyingPartyRegistration> action) {
registrations().forEach(action);
}

@Override
public Spliterator<RelyingPartyRegistration> spliterator() {
return registrations().spliterator();
}

private IterableRelyingPartyRegistrationRepository registrations() {
return this.cache.get("registrations", this.registrationLoader);
}

/**
* Use this cache for the completed {@link RelyingPartyRegistration} instances.
*
* <p>
* Defaults to {@link ConcurrentMapCache}, meaning that the registrations are cached
* without expiry. To turn off the cache, use
* {@link org.springframework.cache.support.NoOpCache}.
* @param cache the {@link Cache} to use
*/
public void setCache(Cache cache) {
Assert.notNull(cache, "cache cannot be null");
this.cache = cache;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.registration;

import java.util.concurrent.Callable;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.cache.Cache;

import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

/**
* Tests for {@link CachingRelyingPartyRegistrationRepository}
*/
@ExtendWith(MockitoExtension.class)
public class CachingRelyingPartyRegistrationRepositoryTests {

@Mock
Callable<Iterable<RelyingPartyRegistration>> callable;

@InjectMocks
CachingRelyingPartyRegistrationRepository registrations;

@Test
public void iteratorWhenResolvableThenPopulatesCache() throws Exception {
given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
this.registrations.iterator();
verify(this.callable).call();
this.registrations.iterator();
verifyNoMoreInteractions(this.callable);
}

@Test
public void iteratorWhenExceptionThenPropagates() throws Exception {
given(this.callable.call()).willThrow(IllegalStateException.class);
assertThatExceptionOfType(Cache.ValueRetrievalException.class).isThrownBy(this.registrations::iterator)
.withCauseInstanceOf(IllegalStateException.class);
}

@Test
public void findByRegistrationIdWhenResolvableThenPopulatesCache() throws Exception {
given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
this.registrations.findByRegistrationId("id");
verify(this.callable).call();
this.registrations.findByRegistrationId("id");
verifyNoMoreInteractions(this.callable);
}

@Test
public void findUniqueByAssertingPartyEntityIdWhenResolvableThenPopulatesCache() throws Exception {
given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
this.registrations.findUniqueByAssertingPartyEntityId("id");
verify(this.callable).call();
this.registrations.findUniqueByAssertingPartyEntityId("id");
verifyNoMoreInteractions(this.callable);
}

}

0 comments on commit 7b39800

Please sign in to comment.