diff --git a/jvector-native/pom.xml b/jvector-native/pom.xml index e30abd9b2..db3fbf7b5 100644 --- a/jvector-native/pom.xml +++ b/jvector-native/pom.xml @@ -35,6 +35,14 @@ + + org.apache.maven.plugins + maven-surefire-plugin + 3.1.2 + + false + + diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java new file mode 100644 index 000000000..9f4208cd3 --- /dev/null +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java @@ -0,0 +1,140 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.disk; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.lang.foreign.ValueLayout.OfFloat; +import java.lang.foreign.ValueLayout.OfInt; +import java.lang.foreign.ValueLayout.OfLong; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; + +/** + * {@link MemorySegment} based implementation of RandomAccessReader. + * MemorySegmentReader doesn't have 2GB file size limitation of {@link SimpleMappedReader}. + */ +public class MemorySegmentReader implements RandomAccessReader { + + private static final OfInt intLayout = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); + private static final OfFloat floatLayout = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); + private static final OfLong longLayout = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); + + private final Arena arena; + private final MemorySegment memory; + private long position = 0; + + public MemorySegmentReader(Path path) throws IOException { + arena = Arena.ofShared(); + try (var ch = FileChannel.open(path, StandardOpenOption.READ)) { + memory = ch.map(MapMode.READ_ONLY, 0L, ch.size(), arena); + } catch (Exception e) { + arena.close(); + throw e; + } + } + + private MemorySegmentReader(Arena arena, MemorySegment memory) { + this.arena = arena; + this.memory = memory; + } + + @Override + public void seek(long offset) { + this.position = offset; + } + + @Override + public long getPosition() { + return position; + } + + @Override + public void readFully(float[] buffer) { + MemorySegment.copy(memory, floatLayout, position, buffer, 0, buffer.length); + position += buffer.length * 4L; + } + + @Override + public void readFully(byte[] b) { + MemorySegment.copy(memory, ValueLayout.JAVA_BYTE, position, b, 0, b.length); + position += b.length; + } + + @Override + public void readFully(ByteBuffer buffer) { + var remaining = buffer.remaining(); + var slice = memory.asSlice(position, remaining).asByteBuffer(); + buffer.put(slice); + position += remaining; + } + + @Override + public void readFully(long[] vector) { + MemorySegment.copy(memory, longLayout, position, vector, 0, vector.length); + position += vector.length * 8L; + } + + @Override + public int readInt() { + var k = memory.get(intLayout, position); + position += 4; + return k; + } + + @Override + public float readFloat() { + var f = memory.get(floatLayout, position); + position += 4; + return f; + } + + @Override + public void read(int[] ints, int offset, int count) { + MemorySegment.copy(memory, intLayout, position, ints, offset, count); + position += count * 4L; + } + + @Override + public void read(float[] floats, int offset, int count) { + MemorySegment.copy(memory, floatLayout, position, floats, offset, count); + position += count * 4L; + } + + /** + * Loads the contents of the mapped segment into physical memory. + * This is a best-effort mechanism. + */ + public void loadMemory() { + memory.load(); + } + + @Override + public void close() { + arena.close(); + } + + public MemorySegmentReader duplicate() { + return new MemorySegmentReader(arena, memory); + } +} diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReaderSupplier.java b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReaderSupplier.java new file mode 100644 index 000000000..2334c00db --- /dev/null +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReaderSupplier.java @@ -0,0 +1,37 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.disk; + +import java.io.IOException; +import java.nio.file.Path; + +public class MemorySegmentReaderSupplier implements ReaderSupplier { + private final MemorySegmentReader reader; + + public MemorySegmentReaderSupplier(Path path) throws IOException { + reader = new MemorySegmentReader(path); + } + + @Override + public RandomAccessReader get() { + return reader.duplicate(); + } + + @Override + public void close() { + reader.close(); + } +} diff --git a/jvector-native/src/test/java/io/github/jbellis/jvector/disk/MemorySegmentReaderTest.java b/jvector-native/src/test/java/io/github/jbellis/jvector/disk/MemorySegmentReaderTest.java new file mode 100644 index 000000000..e57ed6522 --- /dev/null +++ b/jvector-native/src/test/java/io/github/jbellis/jvector/disk/MemorySegmentReaderTest.java @@ -0,0 +1,148 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.disk; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class MemorySegmentReaderTest extends RandomizedTest { + + private Path tempFile; + + @Before + public void setup() throws IOException { + tempFile = Files.createTempFile(getClass().getSimpleName(), ".data"); + + try (var out = new DataOutputStream(new FileOutputStream(tempFile.toFile()))) { + out.write(new byte[] {1, 2, 3, 4, 5, 6, 7}); + for (int i = 0; i < 5; i++) { + out.writeInt((i + 1) * 19); + } + for (int i = 0; i < 5; i++) { + out.writeLong((i + 1) * 19L); + } + for (int i = 0; i < 5; i++) { + out.writeFloat((i + 1) * 19); + } + } + } + + @After + public void tearDown() throws IOException { + Files.deleteIfExists(tempFile); + } + + @Test + public void testReader() throws Exception { + try (var r = new MemorySegmentReader(tempFile)) { + verifyReader(r); + + // read 2nd time from beginning + verifyReader(r); + } + } + + @Test + public void testReaderDuplicate() throws Exception { + try (var r = new MemorySegmentReader(tempFile)) { + for (int i = 0; i < 3; i++) { + var r2 = r.duplicate(); + verifyReader(r2); + } + } + } + + @Test + public void testReaderClose() throws Exception { + var r = new MemorySegmentReader(tempFile); + var r2 = r.duplicate(); + + r.close(); + + try { + r.readInt(); + fail("Should have thrown an exception"); + } catch (IllegalStateException _) { + } + + try { + r2.readInt(); + fail("Should have thrown an exception"); + } catch (IllegalStateException _) { + } + } + + private void verifyReader(MemorySegmentReader r) { + r.seek(0); + var bytes = new byte[7]; + r.readFully(bytes); + for (int i = 0; i < bytes.length; i++) { + assertEquals(i + 1, bytes[i]); + } + + r.seek(0); + var buff = ByteBuffer.allocate(6); + r.readFully(buff); + for (int i = 0; i < buff.remaining(); i++) { + assertEquals(i + 1, buff.get(i)); + } + + r.seek(7); + assertEquals(19, r.readInt()); + + r.seek(7); + var ints = new int[5]; + r.read(ints, 0, ints.length); + for (int i = 0; i < ints.length; i++) { + var k = ints[i]; + assertEquals((i + 1) * 19, k); + } + + r.seek(7 + (4 * 5)); + var longs = new long[5]; + r.readFully(longs); + for (int i = 0; i < longs.length; i++) { + var l = longs[i]; + assertEquals((i + 1) * 19, l); + } + + r.seek(7 + (4 * 5) + (8 * 5)); + assertEquals(19, r.readFloat(), 0.01); + + r.seek(7 + (4 * 5) + (8 * 5)); + var floats = new float[5]; + r.readFully(floats); + for (int i = 0; i < floats.length; i++) { + var f = floats[i]; + assertEquals((i + 1) * 19f, f, 0.01); + } + } +}