Skip to content

Commit

Permalink
Merge pull request #2012 from OneSignal/fix/service-provide-thread-sa…
Browse files Browse the repository at this point in the history
…fety

Make ServiceBuilder.getService thread safe
  • Loading branch information
jkasten2 committed Mar 1, 2024
2 parents 221a02e + 3932338 commit b8a2ba0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ import com.onesignal.debug.internal.logging.Logging
class ServiceProvider(
registrations: List<ServiceRegistration<*>>,
) : IServiceProvider {
private var serviceMap: Map<Class<*>, List<ServiceRegistration<*>>>
private val serviceMap = mutableMapOf<Class<*>, MutableList<ServiceRegistration<*>>>()

init {
val serviceMap = mutableMapOf<Class<*>, MutableList<ServiceRegistration<*>>>()

// go through the registrations to create the service map for easier lookup post-build
for (reg in registrations) {
for (service in reg.services) {
Expand All @@ -23,8 +21,6 @@ class ServiceProvider(
}
}
}

this.serviceMap = serviceMap
}

internal inline fun <reified T : Any> hasService(): Boolean {
Expand All @@ -44,23 +40,27 @@ class ServiceProvider(
}

override fun <T> hasService(c: Class<T>): Boolean {
return serviceMap.containsKey(c)
synchronized(serviceMap) {
return serviceMap.containsKey(c)
}
}

override fun <T> getAllServices(c: Class<T>): List<T> {
val listOfServices: MutableList<T> = mutableListOf()
synchronized(serviceMap) {
val listOfServices: MutableList<T> = mutableListOf()

if (serviceMap.containsKey(c)) {
for (serviceReg in serviceMap!![c]!!) {
val service =
serviceReg.resolve(this) as T?
?: throw Exception("Could not instantiate service: $serviceReg")
if (serviceMap.containsKey(c)) {
for (serviceReg in serviceMap!![c]!!) {
val service =
serviceReg.resolve(this) as T?
?: throw Exception("Could not instantiate service: $serviceReg")

listOfServices.add(service)
listOfServices.add(service)
}
}
}

return listOfServices
return listOfServices
}
}

override fun <T> getService(c: Class<T>): T {
Expand All @@ -74,11 +74,10 @@ class ServiceProvider(
}

override fun <T> getServiceOrNull(c: Class<T>): T? {
Logging.debug("${indent}Retrieving service $c")
// indent += " "
val service = serviceMap[c]?.last()?.resolve(this) as T?
// indent = indent.substring(0, indent.length-2)
return service
synchronized(serviceMap) {
Logging.debug("${indent}Retrieving service $c")
return serviceMap[c]?.last()?.resolve(this) as T?
}
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.onesignal.common

import com.onesignal.common.services.ServiceBuilder
import com.onesignal.common.services.ServiceProvider
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.types.shouldBeSameInstanceAs
import java.util.concurrent.LinkedBlockingQueue

internal interface IMyTestInterface

internal class MySlowConstructorClass : IMyTestInterface {
init {
// NOTE: Keep these println calls, otherwise Kotlin optimizes
// something which cases the test not fail when it should.
println("MySlowConstructorClass BEFORE")
Thread.sleep(10)
println("MySlowConstructorClass AFTER")
}
}

class ServiceProviderTest : FunSpec({

fun setupServiceProviderWithSlowInitClass(): ServiceProvider {
val serviceBuilder = ServiceBuilder()
serviceBuilder.register<MySlowConstructorClass>().provides<IMyTestInterface>()
return serviceBuilder.build()
}

test("getService is thread safe") {
val services = setupServiceProviderWithSlowInitClass()

val queue = LinkedBlockingQueue<IMyTestInterface>()
Thread {
queue.add(services.getService<IMyTestInterface>())
}.start()
Thread {
queue.add(services.getService<IMyTestInterface>())
}.start()

val firstReference = queue.take()
val secondReference = queue.take()
firstReference shouldBeSameInstanceAs secondReference
}
})

0 comments on commit b8a2ba0

Please sign in to comment.