Allow reading vectors where dim is in the file (#130138)

This allows the configuration to set `dim` to `-1` so that the dimension is
read from the vector file itself.

Additionally, allows setting `numQueries` to `0` to easily skip the search
phase.
This commit is contained in:
Benjamin Trent 2025-06-26 18:38:21 -04:00 committed by GitHub
parent f3c5438799
commit f478f849e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 66 additions and 27 deletions

View file

@ -262,8 +262,10 @@ record CmdLineArgs(
if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided");
}
if (dimensions <= 0) {
throw new IllegalArgumentException("dimensions must be a positive integer");
if (dimensions <= 0 && dimensions != -1) {
throw new IllegalArgumentException(
"dimensions must be a positive integer or -1 for when dimension is available in the vector file"
);
}
return new CmdLineArgs(
docVectors,

View file

@ -200,7 +200,7 @@ public class KnnIndexTester {
knnIndexer.numSegments(result);
}
}
if (cmdLineArgs.queryVectors() != null) {
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
knnSearcher.runSearch(result);
}

View file

@ -64,7 +64,7 @@ class KnnIndexer {
private final Path docsPath;
private final Path indexPath;
private final VectorEncoding vectorEncoding;
private final int dim;
private int dim;
private final VectorSimilarityFunction similarityFunction;
private final Codec codec;
private final int numDocs;
@ -106,10 +106,6 @@ class KnnIndexer {
iwc.setMaxFullFlushMergeWaitMillis(0);
FieldType fieldType = switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
};
iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
@Override
public boolean isEnabled(String component) {
@ -137,7 +133,26 @@ class KnnIndexer {
FileChannel in = FileChannel.open(docsPath)
) {
long docsPathSizeInBytes = in.size();
if (docsPathSizeInBytes % ((long) dim * vectorEncoding.byteSize) != 0) {
int offsetByteSize = 0;
if (dim == -1) {
offsetByteSize = 4;
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
if (bytesRead < 4) {
throw new IllegalArgumentException(
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
);
}
dim = preamble.getInt(0);
if (dim <= 0) {
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
}
}
FieldType fieldType = switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
};
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
throw new IllegalArgumentException(
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
);
@ -150,7 +165,7 @@ class KnnIndexer {
vectorEncoding.byteSize
);
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding);
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
AtomicInteger numDocsIndexed = new AtomicInteger();
List<Future<?>> threads = new ArrayList<>();
@ -271,21 +286,24 @@ class KnnIndexer {
static class VectorReader {
final float[] target;
final int offsetByteSize;
final ByteBuffer bytes;
final FileChannel input;
long position;
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) throws IOException {
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int offsetByteSize) throws IOException {
// check if dim is set as preamble in the file:
int bufferSize = dim * vectorEncoding.byteSize;
if (input.size() % ((long) dim * vectorEncoding.byteSize) != 0) {
if (input.size() % ((long) dim * vectorEncoding.byteSize + offsetByteSize) != 0) {
throw new IllegalArgumentException(
"vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size()
);
}
return new VectorReader(input, dim, bufferSize);
return new VectorReader(input, dim, bufferSize, offsetByteSize);
}
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
VectorReader(FileChannel input, int dim, int bufferSize, int offsetByteSize) throws IOException {
this.offsetByteSize = offsetByteSize;
this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
this.input = input;
this.target = new float[dim];
@ -293,14 +311,14 @@ class KnnIndexer {
}
void reset() throws IOException {
position = 0;
position = offsetByteSize;
input.position(position);
}
private void readNext() throws IOException {
int bytesRead = Channels.readFromFileChannel(this.input, position, bytes);
if (bytesRead < bytes.capacity()) {
position = 0;
position = offsetByteSize;
bytes.position(0);
// wrap around back to the start of the file if we hit the end:
logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again");
@ -312,7 +330,7 @@ class KnnIndexer {
);
}
}
position += bytesRead;
position += bytesRead + offsetByteSize;
bytes.position(0);
}

View file

@ -40,6 +40,7 @@ import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.MMapDirectory;
import org.elasticsearch.common.io.Channels;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.search.profile.query.QueryProfiler;
@ -87,7 +88,7 @@ class KnnSearcher {
private final int efSearch;
private final int nProbe;
private final KnnIndexTester.IndexType indexType;
private final int dim;
private int dim;
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
private final float overSamplingFactor;
@ -117,6 +118,7 @@ class KnnSearcher {
TopDocs[] results = new TopDocs[numQueryVectors];
int[][] resultIds = new int[numQueryVectors][];
long elapsed, totalCpuTimeMS, totalVisited = 0;
int offsetByteSize = 0;
try (
FileChannel input = FileChannel.open(queryPath);
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
@ -128,7 +130,19 @@ class KnnSearcher {
+ " bytes, assuming vector count is "
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
);
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding);
if (dim == -1) {
offsetByteSize = 4;
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
int bytesRead = Channels.readFromFileChannel(input, 0, preamble);
if (bytesRead < 4) {
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" does not contain a valid dims?");
}
dim = preamble.getInt(0);
if (dim <= 0) {
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim);
}
}
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
long startNS;
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
try (DirectoryReader reader = DirectoryReader.open(dir)) {
@ -191,7 +205,7 @@ class KnnSearcher {
}
}
logger.info("checking results");
int[][] nn = getOrCalculateExactNN();
int[][] nn = getOrCalculateExactNN(offsetByteSize);
finalResults.avgRecall = checkResults(resultIds, nn, topK);
finalResults.qps = (1000f * numQueryVectors) / elapsed;
finalResults.avgLatency = (float) elapsed / numQueryVectors;
@ -200,7 +214,7 @@ class KnnSearcher {
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
}
private int[][] getOrCalculateExactNN() throws IOException {
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
// look in working directory for cached nn file
String hash = Integer.toString(
Objects.hash(
@ -228,9 +242,9 @@ class KnnSearcher {
// checking low-precision recall
int[][] nn;
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
nn = computeExactNNByte(queryPath);
nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
} else {
nn = computeExactNN(queryPath);
nn = computeExactNN(queryPath, vectorFileOffsetBytes);
}
writeExactNN(nn, nnPath);
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
@ -356,12 +370,17 @@ class KnnSearcher {
}
}
private int[][] computeExactNN(Path queryPath) throws IOException {
private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
try (FileChannel qIn = FileChannel.open(queryPath)) {
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(
qIn,
dim,
VectorEncoding.FLOAT32,
vectorFileOffsetBytes
);
for (int i = 0; i < numQueryVectors; i++) {
float[] queryVector = new float[dim];
queryReader.next(queryVector);
@ -373,12 +392,12 @@ class KnnSearcher {
}
}
private int[][] computeExactNNByte(Path queryPath) throws IOException {
private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
try (FileChannel qIn = FileChannel.open(queryPath)) {
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE);
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE, vectorFileOffsetBytes);
for (int i = 0; i < numQueryVectors; i++) {
byte[] queryVector = new byte[dim];
queryReader.next(queryVector);