Skip to content

Commit

Permalink
✅ Add unit tests for NetworkUtils (#1941)
Browse files Browse the repository at this point in the history
  • Loading branch information
guiyanakuang committed Sep 22, 2024
1 parent b60ce2a commit d911f33
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 16 deletions.
6 changes: 6 additions & 0 deletions composeApp/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,12 @@ tasks.withType<Test> {
showCauses = true
showStackTraces = true
}
jvmArgs(
"--add-opens",
"java.base/java.net=ALL-UNNAMED",
"--add-opens",
"java.base/java.lang.reflect=ALL-UNNAMED",
)
}

// region Work around temporary Compose bugs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ interface NetUtils {
): Boolean

fun getPreferredLocalIPAddress(): String?

fun clearProviderCache()
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ class ValueProvider<T> {
lastSuccessfulValue
}
}

fun clear() {
lastSuccessfulValue = null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,30 @@ object DesktopNetUtils : NetUtils {
private val preferredLocalIPAddress = ValueProvider<String?>()

// Get all potential local IP addresses
private fun getAllLocalAddresses(): Sequence<Pair<HostInfo, String>> {
val networkInterfaces = Collections.list(NetworkInterface.getNetworkInterfaces())

networkInterfaces.forEach { nic ->
logger.info { "Network interface: ${nic.name}" }
nic.interfaceAddresses.forEach { addr ->
logger.info { "\t\tInterface address: ${addr.address.hostAddress}" }
}
}

return networkInterfaces
fun getAllLocalAddresses(): Sequence<Pair<HostInfo, String>> {
return Collections.list(NetworkInterface.getNetworkInterfaces())
.asSequence()
.filter { it.isUp && !it.isLoopback && !it.isVirtual }
.flatMap { nic ->
nic.interfaceAddresses.asSequence().map { Pair(it, nic.name) }
}
.filter { (addr, _) ->
.filter { (addr, nicName) ->
val address = addr.address
if (address is Inet4Address) {
val hostAddress = address.hostAddress
hostAddress != null &&
!hostAddress.endsWith(".0") &&
!hostAddress.endsWith(".1") &&
!hostAddress.endsWith(".255")
val networkPrefixLength = addr.networkPrefixLength
val isLocalAddress =
hostAddress != null &&
!hostAddress.endsWith(".0") &&
!hostAddress.endsWith(".1") &&
!hostAddress.endsWith(".255")
logger.info {
"get local address, Network interface: $nicName " +
"address: $hostAddress networkPrefixLength: $networkPrefixLength"
}
isLocalAddress
} else {
logger.info { "Network interface: $nicName is not local address" }
false
}
}
Expand Down Expand Up @@ -122,4 +121,9 @@ object DesktopNetUtils : NetUtils {
}
}
}

override fun clearProviderCache() {
hostListProvider.clear()
preferredLocalIPAddress.clear()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package com.crosspaste.utils

import com.crosspaste.utils.DesktopNetUtils.getAllLocalAddresses
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.unmockkAll
import io.mockk.verify
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import java.net.Inet4Address
import java.net.InterfaceAddress
import java.net.NetworkInterface
import java.util.Collections
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNull

class NetworkUtilsTest {

@BeforeEach
fun setUp() {
mockkStatic(NetworkInterface::class)
}

@AfterEach
fun tearDown() {
unmockkAll()
getNetUtils().clearProviderCache()
}

@Test
fun testGetAllLocalAddresses() {
// Use MockK to mock NetworkInterface and related classes
mockkStatic(NetworkInterface::class)

val nic1 = mockk<NetworkInterface>()
val nic2 = mockk<NetworkInterface>()
val nic3 = mockk<NetworkInterface>()

every { nic1.isUp } returns true
every { nic1.isLoopback } returns false
every { nic1.isVirtual } returns false
every { nic1.name } returns "eth0"

every { nic2.isUp } returns true
every { nic2.isLoopback } returns false
every { nic2.isVirtual } returns false
every { nic2.name } returns "wlan0"

every { nic3.isUp } returns true
every { nic3.isLoopback } returns false
every { nic3.isVirtual } returns false
every { nic3.name } returns "null0"

val addr1 = mockk<InterfaceAddress>()
val addr2 = mockk<InterfaceAddress>()
val addr3 = mockk<InterfaceAddress>()
val inetAddr1 = mockk<Inet4Address>()
val inetAddr2 = mockk<Inet4Address>()

every { addr1.address } returns inetAddr1
every { addr2.address } returns inetAddr2
every { addr3.address } returns null
every { inetAddr1.hostAddress } returns "192.168.1.8"
every { inetAddr2.hostAddress } returns "10.0.0.5"
every { addr1.networkPrefixLength } returns 24
every { addr2.networkPrefixLength } returns 16
every { addr3.networkPrefixLength } returns 0

every { nic1.interfaceAddresses } returns listOf(addr1)
every { nic2.interfaceAddresses } returns listOf(addr2)
every { nic3.interfaceAddresses } returns listOf(addr3)

val networkInterfaces = Collections.enumeration(listOf(nic1, nic2, nic3))
every { NetworkInterface.getNetworkInterfaces() } returns networkInterfaces

// Execute the test
val result = getAllLocalAddresses().toList()

// Verify the results
assertEquals(2, result.size)
assertEquals("192.168.1.8", result[0].first.hostAddress)
assertEquals(24, result[0].first.networkPrefixLength)
assertEquals("eth0", result[0].second)
assertEquals("10.0.0.5", result[1].first.hostAddress)
assertEquals(16, result[1].first.networkPrefixLength)
assertEquals("wlan0", result[1].second)

// Verify that NetworkInterface.getNetworkInterfaces() was called
verify { NetworkInterface.getNetworkInterfaces() }

// Verify that the interface with null address was skipped
assertFalse(result.any { it.second == "null0" })
}

@Test
fun `getPreferredLocalIPAddress returns correct IP when valid interfaces exist`() {
val nic1 = mockNetworkInterface("eth0", "192.168.1.100", 24)
val nic2 = mockNetworkInterface("wlan0", "192.168.2.100", 24)

every { NetworkInterface.getNetworkInterfaces() } returns Collections.enumeration(listOf(nic1, nic2))

val result = DesktopNetUtils.getPreferredLocalIPAddress()

assertEquals("192.168.1.100", result)
}

@Test
fun `getPreferredLocalIPAddress returns null when no valid interfaces exist`() {
val nic = mockNetworkInterface("lo", "127.0.0.1", 8)

every { NetworkInterface.getNetworkInterfaces() } returns Collections.enumeration(listOf(nic))

val result = DesktopNetUtils.getPreferredLocalIPAddress()

assertNull(result)
}

@Test
fun `getPreferredLocalIPAddress prefers eth interfaces over others`() {
val nic1 = mockNetworkInterface("wlan0", "192.168.2.100", 24)
val nic2 = mockNetworkInterface("eth0", "192.168.1.100", 24)

every { NetworkInterface.getNetworkInterfaces() } returns Collections.enumeration(listOf(nic1, nic2))

val result = DesktopNetUtils.getPreferredLocalIPAddress()

assertEquals("192.168.1.100", result)
}

private fun mockNetworkInterface(
name: String,
ip: String,
prefixLength: Short,
): NetworkInterface {
return mockk<NetworkInterface>().apply {
every { isUp } returns true
every { isLoopback } returns false
every { isVirtual } returns false
every { this@apply.name } returns name
every { interfaceAddresses } returns
listOf(
mockk<InterfaceAddress>().apply {
every { address } returns
mockk<Inet4Address>().apply {
every { hostAddress } returns ip
}
every { networkPrefixLength } returns prefixLength
},
)
}
}
}

0 comments on commit d911f33

Please sign in to comment.