Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread-safe ServiceLoader usage #892

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ public String compact() {

if (this.serializer == null) { // try to find one based on the services available
//noinspection unchecked
json(Services.loadFirst(Serializer.class));
json(Services.get(Serializer.class));
}

if (!Collections.isEmpty(claims)) { // normalize so we have one object to deal with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ public JwtParser build() {

if (this.deserializer == null) {
//noinspection unchecked
json(Services.loadFirst(Deserializer.class));
json(Services.get(Deserializer.class));
}
if (this.signingKeyResolver != null && this.signatureVerificationKey != null) {
String msg = "Both a 'signingKeyResolver and a 'verifyWith' key cannot be configured. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public B json(Deserializer<Map<String, ?>> reader) {
public final Parser<T> build() {
if (this.deserializer == null) {
//noinspection unchecked
this.deserializer = Services.loadFirst(Deserializer.class);
this.deserializer = Services.get(Deserializer.class);
}
return doBuild();
}
Expand Down
100 changes: 35 additions & 65 deletions impl/src/main/java/io/jsonwebtoken/impl/lang/Services.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@
*/
package io.jsonwebtoken.impl.lang;

import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ServiceLoader;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import static io.jsonwebtoken.lang.Collections.arrayToList;

/**
* Helper class for loading services from the classpath, using a {@link ServiceLoader}. Decouples loading logic for
* better separation of concerns and testability.
*/
public final class Services {

private static ConcurrentMap<Class<?>, ServiceLoader<?>> SERVICE_CACHE = new ConcurrentHashMap<>();
private static final ConcurrentMap<Class<?>, Object> SERVICES = new ConcurrentHashMap<>();

private static final List<ClassLoaderAccessor> CLASS_LOADER_ACCESSORS = arrayToList(new ClassLoaderAccessor[] {
private static final List<ClassLoaderAccessor> CLASS_LOADER_ACCESSORS = Arrays.asList(new ClassLoaderAccessor[]{
new ClassLoaderAccessor() {
@Override
public ClassLoader getClassLoader() {
Expand All @@ -54,86 +53,57 @@ public ClassLoader getClassLoader() {
}
});

private Services() {}

/**
* Loads and instantiates all service implementation of the given SPI class and returns them as a List.
*
* @param spi The class of the Service Provider Interface
* @param <T> The type of the SPI
* @return An unmodifiable list with an instance of all available implementations of the SPI. No guarantee is given
* on the order of implementations, if more than one.
*/
public static <T> List<T> loadAll(Class<T> spi) {
Assert.notNull(spi, "Parameter 'spi' must not be null.");

ServiceLoader<T> serviceLoader = serviceLoader(spi);
if (serviceLoader != null) {

List<T> implementations = new ArrayList<>();
for (T implementation : serviceLoader) {
implementations.add(implementation);
}
return implementations;
}

throw new UnavailableImplementationException(spi);
private Services() {
}

/**
* Loads the first available implementation the given SPI class from the classpath. Uses the {@link ServiceLoader}
* to find implementations. When multiple implementations are available it will return the first one that it
* encounters. There is no guarantee with regard to ordering.
* Returns the first available implementation for the given SPI class, checking an internal thread-safe cache first,
* and, if not found, using a {@link ServiceLoader} to find implementations. When multiple implementations are
* available it will return the first one that it encounters. There is no guarantee with regard to ordering.
*
* @param spi The class of the Service Provider Interface
* @param <T> The type of the SPI
* @return A new instance of the service.
* @throws UnavailableImplementationException When no implementation the SPI is available on the classpath.
* @return The first available instance of the service.
* @throws UnavailableImplementationException When no implementation of the SPI class can be found.
*/
public static <T> T loadFirst(Class<T> spi) {
Assert.notNull(spi, "Parameter 'spi' must not be null.");

ServiceLoader<T> serviceLoader = serviceLoader(spi);
if (serviceLoader != null) {
return serviceLoader.iterator().next();
public static <T> T get(Class<T> spi) {
lhazlewood marked this conversation as resolved.
Show resolved Hide resolved
// TODO: JDK8, replace this find/putIfAbsent logic with ConcurrentMap.computeIfAbsent
T instance = findCached(spi);
if (instance == null) {
instance = loadFirst(spi); // throws UnavailableImplementationException if not found, which is what we want
SERVICES.putIfAbsent(spi, instance); // cache if not already cached
}

throw new UnavailableImplementationException(spi);
return instance;
}

/**
* Returns a ServiceLoader for <code>spi</code> class, checking multiple classloaders. The ServiceLoader
* will be cached if it contains at least one implementation of the <code>spi</code> class.<BR>
*
* <b>NOTE:</b> Only the first Serviceloader will be cached.
* @param spi The interface or abstract class representing the service loader.
* @return A service loader, or null if no implementations are found
* @param <T> The type of the SPI.
*/
private static <T> ServiceLoader<T> serviceLoader(Class<T> spi) {
// TODO: JDK8, replace this get/putIfAbsent logic with ConcurrentMap.computeIfAbsent
ServiceLoader<T> serviceLoader = (ServiceLoader<T>) SERVICE_CACHE.get(spi);
if (serviceLoader != null) {
return serviceLoader;
private static <T> T findCached(Class<T> spi) {
Assert.notNull(spi, "Service interface cannot be null.");
Object obj = SERVICES.get(spi);
if (obj != null) {
return Assert.isInstanceOf(spi, obj, "Unexpected cached service implementation type.");
}
return null;
}

for (ClassLoaderAccessor classLoaderAccessor : CLASS_LOADER_ACCESSORS) {
serviceLoader = ServiceLoader.load(spi, classLoaderAccessor.getClassLoader());
if (serviceLoader.iterator().hasNext()) {
SERVICE_CACHE.putIfAbsent(spi, serviceLoader);
return serviceLoader;
private static <T> T loadFirst(Class<T> spi) {
for (ClassLoaderAccessor accessor : CLASS_LOADER_ACCESSORS) {
ServiceLoader<T> loader = ServiceLoader.load(spi, accessor.getClassLoader());
Assert.stateNotNull(loader, "JDK ServiceLoader#load should never return null.");
Iterator<T> i = loader.iterator();
Assert.stateNotNull(i, "JDK ServiceLoader#iterator() should never return null.");
if (i.hasNext()) {
return i.next();
}
}

return null;
throw new UnavailableImplementationException(spi);
}

/**
* Clears internal cache of ServiceLoaders. This is useful when testing, or for applications that dynamically
* Clears internal cache of service singletons. This is useful when testing, or for applications that dynamically
* change classloaders.
*/
public static void reload() {
SERVICE_CACHE.clear();
SERVICES.clear();
}

private interface ClassLoaderAccessor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private JwksBridge() {

@SuppressWarnings({"unchecked", "unused"}) // used via reflection by io.jsonwebtoken.security.Jwks
public static String UNSAFE_JSON(Jwk<?> jwk) {
Serializer<Map<String, ?>> serializer = Services.loadFirst(Serializer.class);
Serializer<Map<String, ?>> serializer = Services.get(Serializer.class);
Assert.stateNotNull(serializer, "Serializer lookup failed. Ensure JSON impl .jar is in the runtime classpath.");
NamedSerializer ser = new NamedSerializer("JWK", serializer);
ByteArrayOutputStream out = new ByteArrayOutputStream(512);
Expand Down
4 changes: 2 additions & 2 deletions impl/src/test/groovy/io/jsonwebtoken/JwtsTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class JwtsTest {
}

static def toJson(def o) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
def out = new ByteArrayOutputStream()
serializer.serialize(o, out)
return Strings.utf8(out.toByteArray())
Expand Down Expand Up @@ -1192,7 +1192,7 @@ class JwtsTest {
int j = jws.lastIndexOf('.')
def b64 = jws.substring(i, j)
def json = Strings.utf8(Decoders.BASE64URL.decode(b64))
def deser = Services.loadFirst(Deserializer)
def deser = Services.get(Deserializer)
def m = deser.deserialize(new StringReader(json)) as Map<String,?>

assertEquals aud, m.get('aud') // single string value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import static org.junit.Assert.fail

class RFC7515AppendixETest {

static final Serializer<Map<String, ?>> serializer = Services.loadFirst(Serializer)
static final Deserializer<Map<String, ?>> deserializer = Services.loadFirst(Deserializer)
static final Serializer<Map<String, ?>> serializer = Services.get(Serializer)
static final Deserializer<Map<String, ?>> deserializer = Services.get(Deserializer)

static byte[] ser(def value) {
ByteArrayOutputStream baos = new ByteArrayOutputStream(512)
Expand Down
4 changes: 1 addition & 3 deletions impl/src/test/groovy/io/jsonwebtoken/RFC7797Test.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,9 @@ class RFC7797Test {
def claims = Jwts.claims().subject('me').build()

ByteArrayOutputStream out = new ByteArrayOutputStream()
Services.loadFirst(Serializer).serialize(claims, out)
Services.get(Serializer).serialize(claims, out)
byte[] content = out.toByteArray()

//byte[] content = Services.loadFirst(Serializer).serialize(claims)

String s = Jwts.builder().signWith(key).content(content).encodePayload(false).compact()

// But verify with 3 types of sources: string, byte array, and two different kinds of InputStreams:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ class DefaultJwtBuilderTest {
private DefaultJwtBuilder builder

private static byte[] serialize(Map<String, ?> map) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
ByteArrayOutputStream out = new ByteArrayOutputStream(512)
serializer.serialize(map, out)
return out.toByteArray()
}

private static Map<String, ?> deser(byte[] data) {
def reader = Streams.reader(data)
Map<String, ?> m = Services.loadFirst(Deserializer).deserialize(reader) as Map<String, ?>
Map<String, ?> m = Services.get(Deserializer).deserialize(reader) as Map<String, ?>
return m
}

Expand Down Expand Up @@ -749,7 +749,7 @@ class DefaultJwtBuilderTest {
// so we need to check the raw payload:
def encoded = new JwtTokenizer().tokenize(Streams.reader(jwt)).getPayload()
byte[] bytes = Decoders.BASE64URL.decode(encoded)
def claims = Services.loadFirst(Deserializer).deserialize(Streams.reader(bytes))
def claims = Services.get(Deserializer).deserialize(Streams.reader(bytes))

assertEquals two, claims.aud
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class DefaultJwtParserTest {
}

private static byte[] serialize(Map<String, ?> map) {
def serializer = Services.loadFirst(Serializer)
def serializer = Services.get(Serializer)
ByteArrayOutputStream out = new ByteArrayOutputStream(512)
serializer.serialize(map, out)
return out.toByteArray()
Expand Down
2 changes: 1 addition & 1 deletion impl/src/test/groovy/io/jsonwebtoken/impl/RfcTests.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class RfcTests {

static final Map<String, ?> jsonToMap(String json) {
Reader r = new CharSequenceReader(json)
Map<String, ?> m = Services.loadFirst(Deserializer).deserialize(r) as Map<String, ?>
Map<String, ?> m = Services.get(Deserializer).deserialize(r) as Map<String, ?>
return m
}

Expand Down
25 changes: 7 additions & 18 deletions impl/src/test/groovy/io/jsonwebtoken/impl/lang/ServicesTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,21 @@ import io.jsonwebtoken.impl.DefaultStubService
import org.junit.After
import org.junit.Test

import static org.junit.Assert.*
import static org.junit.Assert.assertEquals
import static org.junit.Assert.assertNotNull

class ServicesTest {

@Test
void testSuccessfulLoading() {
def factory = Services.loadFirst(StubService)
assertNotNull factory
assertEquals(DefaultStubService, factory.class)
def service = Services.get(StubService)
assertNotNull service
assertEquals(DefaultStubService, service.class)
}

@Test(expected = UnavailableImplementationException)
void testLoadFirstUnavailable() {
Services.loadFirst(NoService.class)
}

@Test
void testLoadAllAvailable() {
def list = Services.loadAll(StubService.class)
assertEquals 1, list.size()
assertTrue list[0] instanceof StubService
}

@Test(expected = UnavailableImplementationException)
void testLoadAllUnavailable() {
Services.loadAll(NoService.class)
void testLoadUnavailable() {
Services.get(NoService.class)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class RFC7518AppendixCTest {
}

private static final Map<String, ?> fromJson(String s) {
return Services.loadFirst(Deserializer).deserialize(new StringReader(s)) as Map<String, ?>
return Services.get(Deserializer).deserialize(new StringReader(s)) as Map<String, ?>
}

private static EcPrivateJwk readJwk(String json) {
Expand Down