diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index a5f66ce5ea83..61113866c9f5 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -262,8 +262,10 @@ record CmdLineArgs( if (docVectors == null) { throw new IllegalArgumentException("Document vectors path must be provided"); } - if (dimensions <= 0) { - throw new IllegalArgumentException("dimensions must be a positive integer"); + if (dimensions <= 0 && dimensions != -1) { + throw new IllegalArgumentException( + "dimensions must be a positive integer or -1 for when dimension is available in the vector file" + ); } return new CmdLineArgs( docVectors, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index dcce1fd304b0..685f88372c70 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -200,7 +200,7 @@ public class KnnIndexTester { knnIndexer.numSegments(result); } } - if (cmdLineArgs.queryVectors() != null) { + if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) { KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs); knnSearcher.runSearch(result); } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java index f0c1648ae52d..40eb8424aeb1 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java @@ -64,7 +64,7 @@ class KnnIndexer { private final Path docsPath; private final Path indexPath; private final VectorEncoding vectorEncoding; - private final int dim; + private int dim; private final VectorSimilarityFunction similarityFunction; private final Codec codec; private final int numDocs; @@ -106,10 +106,6 @@ class KnnIndexer { iwc.setMaxFullFlushMergeWaitMillis(0); - FieldType fieldType = switch (vectorEncoding) { - case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction); - case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction); - }; iwc.setInfoStream(new PrintStreamInfoStream(System.out) { @Override public boolean isEnabled(String component) { @@ -137,7 +133,26 @@ class KnnIndexer { FileChannel in = FileChannel.open(docsPath) ) { long docsPathSizeInBytes = in.size(); - if (docsPathSizeInBytes % ((long) dim * vectorEncoding.byteSize) != 0) { + int offsetByteSize = 0; + if (dim == -1) { + offsetByteSize = 4; + ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); + int bytesRead = Channels.readFromFileChannel(in, 0, preamble); + if (bytesRead < 4) { + throw new IllegalArgumentException( + "docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes + ); + } + dim = preamble.getInt(0); + if (dim <= 0) { + throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim); + } + } + FieldType fieldType = switch (vectorEncoding) { + case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction); + case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction); + }; + if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) { throw new IllegalArgumentException( "docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes ); @@ -150,7 +165,7 @@ class KnnIndexer { vectorEncoding.byteSize ); - VectorReader inReader = VectorReader.create(in, dim, vectorEncoding); + VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize); try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) { AtomicInteger numDocsIndexed = new AtomicInteger(); List> threads = new ArrayList<>(); @@ -271,21 +286,24 @@ class KnnIndexer { static class VectorReader { final float[] target; + final int offsetByteSize; final ByteBuffer bytes; final FileChannel input; long position; - static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) throws IOException { + static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int offsetByteSize) throws IOException { + // check if dim is set as preamble in the file: int bufferSize = dim * vectorEncoding.byteSize; - if (input.size() % ((long) dim * vectorEncoding.byteSize) != 0) { + if (input.size() % ((long) dim * vectorEncoding.byteSize + offsetByteSize) != 0) { throw new IllegalArgumentException( "vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size() ); } - return new VectorReader(input, dim, bufferSize); + return new VectorReader(input, dim, bufferSize, offsetByteSize); } - VectorReader(FileChannel input, int dim, int bufferSize) throws IOException { + VectorReader(FileChannel input, int dim, int bufferSize, int offsetByteSize) throws IOException { + this.offsetByteSize = offsetByteSize; this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN); this.input = input; this.target = new float[dim]; @@ -293,14 +311,14 @@ class KnnIndexer { } void reset() throws IOException { - position = 0; + position = offsetByteSize; input.position(position); } private void readNext() throws IOException { int bytesRead = Channels.readFromFileChannel(this.input, position, bytes); if (bytesRead < bytes.capacity()) { - position = 0; + position = offsetByteSize; bytes.position(0); // wrap around back to the start of the file if we hit the end: logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again"); @@ -312,7 +330,7 @@ class KnnIndexer { ); } } - position += bytesRead; + position += bytesRead + offsetByteSize; bytes.position(0); } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index ec90dd46ef5b..7dd6f2894a20 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -40,6 +40,7 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.MMapDirectory; +import org.elasticsearch.common.io.Channels; import org.elasticsearch.core.PathUtils; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -87,7 +88,7 @@ class KnnSearcher { private final int efSearch; private final int nProbe; private final KnnIndexTester.IndexType indexType; - private final int dim; + private int dim; private final VectorSimilarityFunction similarityFunction; private final VectorEncoding vectorEncoding; private final float overSamplingFactor; @@ -117,6 +118,7 @@ class KnnSearcher { TopDocs[] results = new TopDocs[numQueryVectors]; int[][] resultIds = new int[numQueryVectors][]; long elapsed, totalCpuTimeMS, totalVisited = 0; + int offsetByteSize = 0; try ( FileChannel input = FileChannel.open(queryPath); ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread")) @@ -128,7 +130,19 @@ class KnnSearcher { + " bytes, assuming vector count is " + (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize)) ); - KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding); + if (dim == -1) { + offsetByteSize = 4; + ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); + int bytesRead = Channels.readFromFileChannel(input, 0, preamble); + if (bytesRead < 4) { + throw new IllegalArgumentException("queryPath \"" + queryPath + "\" does not contain a valid dims?"); + } + dim = preamble.getInt(0); + if (dim <= 0) { + throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim); + } + } + KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize); long startNS; try (MMapDirectory dir = new MMapDirectory(indexPath)) { try (DirectoryReader reader = DirectoryReader.open(dir)) { @@ -191,7 +205,7 @@ class KnnSearcher { } } logger.info("checking results"); - int[][] nn = getOrCalculateExactNN(); + int[][] nn = getOrCalculateExactNN(offsetByteSize); finalResults.avgRecall = checkResults(resultIds, nn, topK); finalResults.qps = (1000f * numQueryVectors) / elapsed; finalResults.avgLatency = (float) elapsed / numQueryVectors; @@ -200,7 +214,7 @@ class KnnSearcher { finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed; } - private int[][] getOrCalculateExactNN() throws IOException { + private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException { // look in working directory for cached nn file String hash = Integer.toString( Objects.hash( @@ -228,9 +242,9 @@ class KnnSearcher { // checking low-precision recall int[][] nn; if (vectorEncoding.equals(VectorEncoding.BYTE)) { - nn = computeExactNNByte(queryPath); + nn = computeExactNNByte(queryPath, vectorFileOffsetBytes); } else { - nn = computeExactNN(queryPath); + nn = computeExactNN(queryPath, vectorFileOffsetBytes); } writeExactNN(nn, nnPath); long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms @@ -356,12 +370,17 @@ class KnnSearcher { } } - private int[][] computeExactNN(Path queryPath) throws IOException { + private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException { int[][] result = new int[numQueryVectors][]; try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { List> tasks = new ArrayList<>(); try (FileChannel qIn = FileChannel.open(queryPath)) { - KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.FLOAT32); + KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create( + qIn, + dim, + VectorEncoding.FLOAT32, + vectorFileOffsetBytes + ); for (int i = 0; i < numQueryVectors; i++) { float[] queryVector = new float[dim]; queryReader.next(queryVector); @@ -373,12 +392,12 @@ class KnnSearcher { } } - private int[][] computeExactNNByte(Path queryPath) throws IOException { + private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException { int[][] result = new int[numQueryVectors][]; try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { List> tasks = new ArrayList<>(); try (FileChannel qIn = FileChannel.open(queryPath)) { - KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE); + KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE, vectorFileOffsetBytes); for (int i = 0; i < numQueryVectors; i++) { byte[] queryVector = new byte[dim]; queryReader.next(queryVector);