diff --git a/app/src/main/java/org/apache/roller/weblogger/util/IPBanList.java b/app/src/main/java/org/apache/roller/weblogger/util/IPBanList.java index 4fd3eb12cf..2c9df32f8c 100644 --- a/app/src/main/java/org/apache/roller/weblogger/util/IPBanList.java +++ b/app/src/main/java/org/apache/roller/weblogger/util/IPBanList.java @@ -22,8 +22,10 @@ import java.io.FileReader; import java.io.FileWriter; import java.io.PrintWriter; -import java.util.HashSet; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.roller.weblogger.config.WebloggerConfig; @@ -38,10 +40,10 @@ */ public final class IPBanList { - private static Log log = LogFactory.getLog(IPBanList.class); + private static final Log log = LogFactory.getLog(IPBanList.class); // set of ips that are banned, use a set to ensure uniqueness - private Set bannedIps = new HashSet(); + private volatile Set bannedIps = newThreadSafeSet(); // file listing the ips that are banned private ModifiedFile bannedIpsFile = null; @@ -51,17 +53,17 @@ public final class IPBanList { static { - instance = new IPBanList(); + instance = new IPBanList(() -> WebloggerConfig.getProperty("ipbanlist.file")); } - // private because we are a singleton - private IPBanList() { + // package-private for unit tests + IPBanList(Supplier banIpsFilePathSupplier) { log.debug("INIT"); // load up set of denied ips - String banIpsFilePath = WebloggerConfig.getProperty("ipbanlist.file"); + String banIpsFilePath = banIpsFilePathSupplier.get(); if(banIpsFilePath != null) { ModifiedFile banIpsFile = new ModifiedFile(banIpsFilePath); @@ -82,7 +84,7 @@ public static IPBanList getInstance() { public boolean isBanned(String ip) { // update the banned ips list if needed - this.loadBannedIpsIfNeeded(false); + this.loadBannedIpsIfNeeded(); if(ip != null) { return this.bannedIps.contains(ip); @@ -99,7 +101,7 @@ public void addBannedIp(String ip) { } // update the banned ips list if needed - this.loadBannedIpsIfNeeded(false); + this.loadBannedIpsIfNeeded(); if(!this.bannedIps.contains(ip) && (bannedIpsFile != null && bannedIpsFile.canWrite())) { @@ -127,10 +129,10 @@ public void addBannedIp(String ip) { /** * Check if the banned ips file has changed and needs to be reloaded. */ - private void loadBannedIpsIfNeeded(boolean forceLoad) { + private void loadBannedIpsIfNeeded() { if(bannedIpsFile != null && - (bannedIpsFile.hasChanged() || forceLoad)) { + (bannedIpsFile.hasChanged())) { // need to reload this.loadBannedIps(); @@ -148,7 +150,7 @@ private synchronized void loadBannedIps() { // TODO: optimize this try (BufferedReader in = new BufferedReader(new FileReader(this.bannedIpsFile))) { - HashSet newBannedIpList = new HashSet(); + Set newBannedIpList = newThreadSafeSet(); String ip = null; while((ip = in.readLine()) != null) { @@ -170,7 +172,7 @@ private synchronized void loadBannedIps() { // a simple extension to the File class which tracks if the file has // changed since the last time we checked - private class ModifiedFile extends java.io.File { + private static class ModifiedFile extends java.io.File { private long myLastModified = 0; @@ -189,4 +191,7 @@ public void clearChanged() { } } + private static Set newThreadSafeSet() { + return ConcurrentHashMap.newKeySet(); + } } diff --git a/app/src/test/java/org/apache/roller/weblogger/util/IPBanListTest.java b/app/src/test/java/org/apache/roller/weblogger/util/IPBanListTest.java new file mode 100644 index 0000000000..9d0a97a1d3 --- /dev/null +++ b/app/src/test/java/org/apache/roller/weblogger/util/IPBanListTest.java @@ -0,0 +1,93 @@ +package org.apache.roller.weblogger.util; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class IPBanListTest { + + @TempDir + Path tmpDir; + Path ipBanList; + IPBanList sut; + + @BeforeEach + void setUp() throws IOException { + ipBanList = tmpDir.resolve("ipbanlist.txt"); + Files.createFile(ipBanList); + sut = new IPBanList(() -> ipBanList.toAbsolutePath().toString()); + } + + @Test + @DisplayName("addBanned() adds the given IP address to the file") + void addBannedAddsToFile() { + sut.addBannedIp("10.0.0.1"); + + List ipBanList = readIpBanList(); + assertTrue(ipBanList.contains("10.0.0.1")); + assertEquals(1, ipBanList.size()); + } + + @Test + @DisplayName("addBanned() ignores nulls") + void addBannedIgnoresNulls() { + sut.addBannedIp(null); + + assertTrue(readIpBanList().isEmpty()); + } + + @Test + @DisplayName("isBanned() returns true if the given IP address is banned") + void isBanned() { + sut.addBannedIp("10.0.0.1"); + + assertTrue(sut.isBanned("10.0.0.1")); + } + + @Test + @DisplayName("isBanned() returns false if the given IP address it not banned") + void isBanned2() { + assertFalse(sut.isBanned("10.0.0.1")); + } + + @Test + @DisplayName("isBanned() returns false if the given IP address is null") + void isBanned3() { + assertFalse(sut.isBanned(null)); + } + + @Test + @DisplayName("isBanned() reads the file if needed") + void isBanned4() { + writeIpBanList("10.0.0.1"); + + assertTrue(sut.isBanned("10.0.0.1")); + } + + private void writeIpBanList(String ipAddress) { + try { + Files.writeString(ipBanList, ipAddress); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private List readIpBanList() { + try { + return Files.readAllLines(ipBanList); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +}