Allow reading vectors where dim is in the file (#130138)

This allows the configuration to specify a dim of `-1`, in which case the
`dim` is read from the vector file itself.

Additionally, this allows setting `numQueries` to `0` to easily skip the
search phase.
This commit is contained in:
Benjamin Trent 2025-06-26 18:38:21 -04:00 committed by GitHub
parent f3c5438799
commit f478f849e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 66 additions and 27 deletions

View file

@ -262,8 +262,10 @@ record CmdLineArgs(
if (docVectors == null) { if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided"); throw new IllegalArgumentException("Document vectors path must be provided");
} }
if (dimensions <= 0) { if (dimensions <= 0 && dimensions != -1) {
throw new IllegalArgumentException("dimensions must be a positive integer"); throw new IllegalArgumentException(
"dimensions must be a positive integer or -1 for when dimension is available in the vector file"
);
} }
return new CmdLineArgs( return new CmdLineArgs(
docVectors, docVectors,

View file

@ -200,7 +200,7 @@ public class KnnIndexTester {
knnIndexer.numSegments(result); knnIndexer.numSegments(result);
} }
} }
if (cmdLineArgs.queryVectors() != null) { if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs); KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
knnSearcher.runSearch(result); knnSearcher.runSearch(result);
} }

View file

@ -64,7 +64,7 @@ class KnnIndexer {
private final Path docsPath; private final Path docsPath;
private final Path indexPath; private final Path indexPath;
private final VectorEncoding vectorEncoding; private final VectorEncoding vectorEncoding;
private final int dim; private int dim;
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final Codec codec; private final Codec codec;
private final int numDocs; private final int numDocs;
@ -106,10 +106,6 @@ class KnnIndexer {
iwc.setMaxFullFlushMergeWaitMillis(0); iwc.setMaxFullFlushMergeWaitMillis(0);
FieldType fieldType = switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
};
iwc.setInfoStream(new PrintStreamInfoStream(System.out) { iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
@Override @Override
public boolean isEnabled(String component) { public boolean isEnabled(String component) {
@ -137,7 +133,26 @@ class KnnIndexer {
FileChannel in = FileChannel.open(docsPath) FileChannel in = FileChannel.open(docsPath)
) { ) {
long docsPathSizeInBytes = in.size(); long docsPathSizeInBytes = in.size();
if (docsPathSizeInBytes % ((long) dim * vectorEncoding.byteSize) != 0) { int offsetByteSize = 0;
if (dim == -1) {
offsetByteSize = 4;
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
if (bytesRead < 4) {
throw new IllegalArgumentException(
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
);
}
dim = preamble.getInt(0);
if (dim <= 0) {
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
}
}
FieldType fieldType = switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
};
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes "docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
); );
@ -150,7 +165,7 @@ class KnnIndexer {
vectorEncoding.byteSize vectorEncoding.byteSize
); );
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding); VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) { try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
AtomicInteger numDocsIndexed = new AtomicInteger(); AtomicInteger numDocsIndexed = new AtomicInteger();
List<Future<?>> threads = new ArrayList<>(); List<Future<?>> threads = new ArrayList<>();
@ -271,21 +286,24 @@ class KnnIndexer {
static class VectorReader { static class VectorReader {
final float[] target; final float[] target;
final int offsetByteSize;
final ByteBuffer bytes; final ByteBuffer bytes;
final FileChannel input; final FileChannel input;
long position; long position;
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) throws IOException { static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int offsetByteSize) throws IOException {
// check if dim is set as preamble in the file:
int bufferSize = dim * vectorEncoding.byteSize; int bufferSize = dim * vectorEncoding.byteSize;
if (input.size() % ((long) dim * vectorEncoding.byteSize) != 0) { if (input.size() % ((long) dim * vectorEncoding.byteSize + offsetByteSize) != 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size() "vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size()
); );
} }
return new VectorReader(input, dim, bufferSize); return new VectorReader(input, dim, bufferSize, offsetByteSize);
} }
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException { VectorReader(FileChannel input, int dim, int bufferSize, int offsetByteSize) throws IOException {
this.offsetByteSize = offsetByteSize;
this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN); this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
this.input = input; this.input = input;
this.target = new float[dim]; this.target = new float[dim];
@ -293,14 +311,14 @@ class KnnIndexer {
} }
void reset() throws IOException { void reset() throws IOException {
position = 0; position = offsetByteSize;
input.position(position); input.position(position);
} }
private void readNext() throws IOException { private void readNext() throws IOException {
int bytesRead = Channels.readFromFileChannel(this.input, position, bytes); int bytesRead = Channels.readFromFileChannel(this.input, position, bytes);
if (bytesRead < bytes.capacity()) { if (bytesRead < bytes.capacity()) {
position = 0; position = offsetByteSize;
bytes.position(0); bytes.position(0);
// wrap around back to the start of the file if we hit the end: // wrap around back to the start of the file if we hit the end:
logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again"); logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again");
@ -312,7 +330,7 @@ class KnnIndexer {
); );
} }
} }
position += bytesRead; position += bytesRead + offsetByteSize;
bytes.position(0); bytes.position(0);
} }

View file

@ -40,6 +40,7 @@ import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.store.MMapDirectory;
import org.elasticsearch.common.io.Channels;
import org.elasticsearch.core.PathUtils; import org.elasticsearch.core.PathUtils;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.search.profile.query.QueryProfiler;
@ -87,7 +88,7 @@ class KnnSearcher {
private final int efSearch; private final int efSearch;
private final int nProbe; private final int nProbe;
private final KnnIndexTester.IndexType indexType; private final KnnIndexTester.IndexType indexType;
private final int dim; private int dim;
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding; private final VectorEncoding vectorEncoding;
private final float overSamplingFactor; private final float overSamplingFactor;
@ -117,6 +118,7 @@ class KnnSearcher {
TopDocs[] results = new TopDocs[numQueryVectors]; TopDocs[] results = new TopDocs[numQueryVectors];
int[][] resultIds = new int[numQueryVectors][]; int[][] resultIds = new int[numQueryVectors][];
long elapsed, totalCpuTimeMS, totalVisited = 0; long elapsed, totalCpuTimeMS, totalVisited = 0;
int offsetByteSize = 0;
try ( try (
FileChannel input = FileChannel.open(queryPath); FileChannel input = FileChannel.open(queryPath);
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread")) ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
@ -128,7 +130,19 @@ class KnnSearcher {
+ " bytes, assuming vector count is " + " bytes, assuming vector count is "
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize)) + (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
); );
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding); if (dim == -1) {
offsetByteSize = 4;
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
int bytesRead = Channels.readFromFileChannel(input, 0, preamble);
if (bytesRead < 4) {
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" does not contain a valid dims?");
}
dim = preamble.getInt(0);
if (dim <= 0) {
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim);
}
}
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
long startNS; long startNS;
try (MMapDirectory dir = new MMapDirectory(indexPath)) { try (MMapDirectory dir = new MMapDirectory(indexPath)) {
try (DirectoryReader reader = DirectoryReader.open(dir)) { try (DirectoryReader reader = DirectoryReader.open(dir)) {
@ -191,7 +205,7 @@ class KnnSearcher {
} }
} }
logger.info("checking results"); logger.info("checking results");
int[][] nn = getOrCalculateExactNN(); int[][] nn = getOrCalculateExactNN(offsetByteSize);
finalResults.avgRecall = checkResults(resultIds, nn, topK); finalResults.avgRecall = checkResults(resultIds, nn, topK);
finalResults.qps = (1000f * numQueryVectors) / elapsed; finalResults.qps = (1000f * numQueryVectors) / elapsed;
finalResults.avgLatency = (float) elapsed / numQueryVectors; finalResults.avgLatency = (float) elapsed / numQueryVectors;
@ -200,7 +214,7 @@ class KnnSearcher {
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed; finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
} }
private int[][] getOrCalculateExactNN() throws IOException { private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
// look in working directory for cached nn file // look in working directory for cached nn file
String hash = Integer.toString( String hash = Integer.toString(
Objects.hash( Objects.hash(
@ -228,9 +242,9 @@ class KnnSearcher {
// checking low-precision recall // checking low-precision recall
int[][] nn; int[][] nn;
if (vectorEncoding.equals(VectorEncoding.BYTE)) { if (vectorEncoding.equals(VectorEncoding.BYTE)) {
nn = computeExactNNByte(queryPath); nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
} else { } else {
nn = computeExactNN(queryPath); nn = computeExactNN(queryPath, vectorFileOffsetBytes);
} }
writeExactNN(nn, nnPath); writeExactNN(nn, nnPath);
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
@ -356,12 +370,17 @@ class KnnSearcher {
} }
} }
private int[][] computeExactNN(Path queryPath) throws IOException { private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][]; int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
try (FileChannel qIn = FileChannel.open(queryPath)) { try (FileChannel qIn = FileChannel.open(queryPath)) {
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.FLOAT32); KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(
qIn,
dim,
VectorEncoding.FLOAT32,
vectorFileOffsetBytes
);
for (int i = 0; i < numQueryVectors; i++) { for (int i = 0; i < numQueryVectors; i++) {
float[] queryVector = new float[dim]; float[] queryVector = new float[dim];
queryReader.next(queryVector); queryReader.next(queryVector);
@ -373,12 +392,12 @@ class KnnSearcher {
} }
} }
private int[][] computeExactNNByte(Path queryPath) throws IOException { private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][]; int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
try (FileChannel qIn = FileChannel.open(queryPath)) { try (FileChannel qIn = FileChannel.open(queryPath)) {
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE); KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE, vectorFileOffsetBytes);
for (int i = 0; i < numQueryVectors; i++) { for (int i = 0; i < numQueryVectors; i++) {
byte[] queryVector = new byte[dim]; byte[] queryVector = new byte[dim];
queryReader.next(queryVector); queryReader.next(queryVector);