diff --git a/src/main/java/com/lewdev/probabilitylib/ProbabilityCollection.java b/src/main/java/com/lewdev/probabilitylib/ProbabilityCollection.java index 8cf2953..54447d3 100644 --- a/src/main/java/com/lewdev/probabilitylib/ProbabilityCollection.java +++ b/src/main/java/com/lewdev/probabilitylib/ProbabilityCollection.java @@ -2,6 +2,7 @@ import java.util.Comparator; import java.util.Iterator; +import java.util.Objects; import java.util.TreeSet; import java.util.concurrent.ThreadLocalRandom; @@ -22,7 +23,7 @@ * * * @author Lewys Davies - * @version 0.5 + * @version 0.6 * * @param Type of elements */ @@ -60,12 +61,15 @@ public boolean isEmpty() { /** * @param object * @return True if the collection contains the object, else False + * @throws IllegalArgumentException if object null */ public boolean contains(E object) { + if(object == null) { + throw new IllegalArgumentException("Cannot check if null object is contained in a collection"); + } + return this.collection.stream() - .filter(entry -> entry.getObject().equals(object)) - .findFirst() - .isPresent(); + .anyMatch(entry -> entry.getObject().equals(object)); } /** @@ -78,12 +82,24 @@ public Iterator> iterator() { /** * Add an object to this collection * - * @param object - * @param probability share + * @param object. Not null. + * @param probability share. Must be greater than 0. + * + * @throws IllegalArgumentException if object is null + * @throws IllegalArgumentException if probability <= 0 */ public void add(E object, int probability) { + if(object == null) { + throw new IllegalArgumentException("Cannot add null object"); + } + + if(probability <= 0) { + throw new IllegalArgumentException("Probability must be greater than 0"); + } + this.collection.add(new ProbabilitySetElement(object, probability)); this.totalProbability += probability; + this.updateIndexes(); } @@ -92,32 +108,46 @@ public void add(E object, int probability) { * * @param object * @return True if object was removed, else False. + * + * @throws IllegalArgumentException if object null */ public boolean remove(E object) { + if(object == null) { + throw new IllegalArgumentException("Cannot remove null object"); + } + Iterator> it = this.iterator(); - boolean removed = false; + boolean removed = it.hasNext(); while(it.hasNext()) { - ProbabilitySetElement element = it.next(); - if(element.getObject().equals(object)) { - removed = true; - this.totalProbability -= element.getProbability(); + ProbabilitySetElement entry = it.next(); + if(entry.getObject().equals(object)) { + this.totalProbability -= entry.getProbability(); it.remove(); } } this.updateIndexes(); + return removed; } /** - * @return Random object based on probability + * Get a random object from this collection, based on probability. + * + * @return Random object + * + * @throws IllegalStateException if this collection is empty */ public E get() { + if(this.isEmpty()) { + throw new IllegalStateException("Cannot get an element out of a empty set"); + } + ProbabilitySetElement toFind = new ProbabilitySetElement<>(null, 0); toFind.setIndex(ThreadLocalRandom.current().nextInt(1, this.totalProbability + 1)); - return this.collection.floor(toFind).getObject(); + return Objects.requireNonNull(this.collection.floor(toFind).getObject()); } /** @@ -131,7 +161,8 @@ public final int getTotalProbability() { * Calculate the size of all element's "block" of space: * i.e 1-5, 6-10, 11-14, 15, 16 * - * We then only need to store the start index of each element + * We then only need to store the start index of each element, + * as we make use of the TreeSet#floor */ private void updateIndexes() { int previousIndex = 0; @@ -182,12 +213,12 @@ public final int getProbability() { } // Used internally, see this class's documentation - protected final int getIndex() { + private final int getIndex() { return this.index; } // Used Internally, see this class's documentation - protected final int setIndex(int index) { + private final int setIndex(int index) { this.index = index; return this.index; } diff --git a/src/test/java/com/lewdev/probabilitylib/ProbabilityCollectionTest.java b/src/test/java/com/lewdev/probabilitylib/ProbabilityCollectionTest.java index cffc541..d4d6173 100644 --- a/src/test/java/com/lewdev/probabilitylib/ProbabilityCollectionTest.java +++ b/src/test/java/com/lewdev/probabilitylib/ProbabilityCollectionTest.java @@ -3,13 +3,14 @@ import static org.junit.jupiter.api.Assertions.*; import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; /** * @author Lewys Davies */ public class ProbabilityCollectionTest { - @RepeatedTest(value = 1000) + @RepeatedTest(value = 10_000) public void test_insert() { ProbabilityCollection collection = new ProbabilityCollection<>(); assertEquals(0, collection.size()); @@ -44,12 +45,12 @@ public void test_insert() { } } - @RepeatedTest(value = 1000) + @RepeatedTest(value = 10_000) public void test_remove() { ProbabilityCollection collection = new ProbabilityCollection<>(); - assertTrue(collection.size() == 0); + assertEquals(0, collection.size()); assertTrue(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 0); + assertEquals(0, collection.getTotalProbability()); String t1 = "Hello"; String t2 = "World"; @@ -59,33 +60,36 @@ public void test_remove() { collection.add(t2, 10); collection.add(t3, 10); - assertTrue(collection.size() == 3); + assertEquals(3, collection.size()); assertFalse(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 30); + assertEquals(30, collection.getTotalProbability()); - collection.remove(t2); + // Remove t2 + assertTrue(collection.remove(t2)); - assertTrue(collection.size() == 2); + assertEquals(2, collection.size()); assertFalse(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 20); + assertEquals(20, collection.getTotalProbability()); - collection.remove(t1); + // Remove t1 + assertTrue(collection.remove(t1)); - assertTrue(collection.size() == 1); + assertEquals(1, collection.size()); assertFalse(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 10); + assertEquals(10, collection.getTotalProbability()); - collection.remove(t3); + //Remove t3 + assertTrue(collection.remove(t3)); - assertTrue(collection.size() == 0); - assertTrue(collection.getTotalProbability() == 0); + assertEquals(0, collection.size()); assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); } - @RepeatedTest(value = 1000) + @RepeatedTest(value = 10_000) public void test_remove_duplicates() { ProbabilityCollection collection = new ProbabilityCollection<>(); - assertTrue(collection.size() == 0); + assertEquals(0, collection.size()); assertTrue(collection.isEmpty()); String t1 = "Hello"; @@ -104,27 +108,30 @@ public void test_remove_duplicates() { collection.add(t3, 10); } - assertTrue(collection.size() == 30); + assertEquals(30, collection.size()); assertFalse(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 300); + assertEquals(300, collection.getTotalProbability()); - collection.remove(t2); + //Remove t2 + assertTrue(collection.remove(t2)); - assertTrue(collection.size() == 20); + assertEquals(20, collection.size()); assertFalse(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 200); + assertEquals(200, collection.getTotalProbability()); - collection.remove(t1); + // Remove t1 + assertTrue(collection.remove(t1)); - assertTrue(collection.size() == 10); + assertEquals(10, collection.size()); assertFalse(collection.isEmpty()); + assertEquals(100, collection.getTotalProbability()); - assertTrue(collection.getTotalProbability() == 100); - collection.remove(t3); + //Remove t3 + assertTrue(collection.remove(t3)); - assertTrue(collection.size() == 0); + assertEquals(0, collection.size()); assertTrue(collection.isEmpty()); - assertTrue(collection.getTotalProbability() == 0); + assertEquals(0, collection.getTotalProbability()); } @RepeatedTest(1_000_000) @@ -159,10 +166,82 @@ public void test_probability() { double bResult = b / (double) totalGets * 100; double cResult = c / (double) totalGets * 100; - double acceptableDeviation = 1; + double acceptableDeviation = 1; // % assertTrue(Math.abs(aProb - aResult) <= acceptableDeviation); assertTrue(Math.abs(bProb - bResult) <= acceptableDeviation); assertTrue(Math.abs(cProb - cResult) <= acceptableDeviation); } + + @RepeatedTest(1_000_000) + public void test_get_never_null() { + ProbabilityCollection collection = new ProbabilityCollection<>(); + // Tests get will never return null + // Just one smallest element get, must not return null + collection.add("A", 1); + assertNotNull(collection.get()); + + // Reset state + collection.remove("A"); + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + + // Just one large element, must not return null + collection.add("A", 5_000_000); + assertNotNull(collection.get()); + } + + @Test + public void test_Errors() { + ProbabilityCollection collection = new ProbabilityCollection<>(); + + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); + + // Cannot get from empty collection + assertThrows(IllegalStateException.class, () -> { + collection.get(); + }); + + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); + + // Cannot add null object + assertThrows(IllegalArgumentException.class, () -> { + collection.add(null, 1); + }); + + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); + + // Cannot add prob 0 + assertThrows(IllegalArgumentException.class, () -> { + collection.add("A", 0); + }); + + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); + + // Cannot remove null + assertThrows(IllegalArgumentException.class, () -> { + collection.remove(null); + }); + + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); + + // Cannot contains null + assertThrows(IllegalArgumentException.class, () -> { + collection.contains(null); + }); + + assertEquals(0, collection.size()); + assertTrue(collection.isEmpty()); + assertEquals(0, collection.getTotalProbability()); + } }