mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 01:22:26 -04:00
Allow reading vectors where dim is in the file (#130138)
This allows the configuration to set `dim` to `-1` so that the dimension is read from the vector file itself. Additionally, `numQueries` can now be set to `0` to easily skip the search phase.
This commit is contained in:
parent
f3c5438799
commit
f478f849e3
4 changed files with 66 additions and 27 deletions
|
@ -262,8 +262,10 @@ record CmdLineArgs(
|
||||||
if (docVectors == null) {
|
if (docVectors == null) {
|
||||||
throw new IllegalArgumentException("Document vectors path must be provided");
|
throw new IllegalArgumentException("Document vectors path must be provided");
|
||||||
}
|
}
|
||||||
if (dimensions <= 0) {
|
if (dimensions <= 0 && dimensions != -1) {
|
||||||
throw new IllegalArgumentException("dimensions must be a positive integer");
|
throw new IllegalArgumentException(
|
||||||
|
"dimensions must be a positive integer or -1 for when dimension is available in the vector file"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
return new CmdLineArgs(
|
return new CmdLineArgs(
|
||||||
docVectors,
|
docVectors,
|
||||||
|
|
|
@ -200,7 +200,7 @@ public class KnnIndexTester {
|
||||||
knnIndexer.numSegments(result);
|
knnIndexer.numSegments(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cmdLineArgs.queryVectors() != null) {
|
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
|
||||||
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
|
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
|
||||||
knnSearcher.runSearch(result);
|
knnSearcher.runSearch(result);
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ class KnnIndexer {
|
||||||
private final Path docsPath;
|
private final Path docsPath;
|
||||||
private final Path indexPath;
|
private final Path indexPath;
|
||||||
private final VectorEncoding vectorEncoding;
|
private final VectorEncoding vectorEncoding;
|
||||||
private final int dim;
|
private int dim;
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final Codec codec;
|
private final Codec codec;
|
||||||
private final int numDocs;
|
private final int numDocs;
|
||||||
|
@ -106,10 +106,6 @@ class KnnIndexer {
|
||||||
|
|
||||||
iwc.setMaxFullFlushMergeWaitMillis(0);
|
iwc.setMaxFullFlushMergeWaitMillis(0);
|
||||||
|
|
||||||
FieldType fieldType = switch (vectorEncoding) {
|
|
||||||
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
|
|
||||||
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
|
|
||||||
};
|
|
||||||
iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
|
iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
|
||||||
@Override
|
@Override
|
||||||
public boolean isEnabled(String component) {
|
public boolean isEnabled(String component) {
|
||||||
|
@ -137,7 +133,26 @@ class KnnIndexer {
|
||||||
FileChannel in = FileChannel.open(docsPath)
|
FileChannel in = FileChannel.open(docsPath)
|
||||||
) {
|
) {
|
||||||
long docsPathSizeInBytes = in.size();
|
long docsPathSizeInBytes = in.size();
|
||||||
if (docsPathSizeInBytes % ((long) dim * vectorEncoding.byteSize) != 0) {
|
int offsetByteSize = 0;
|
||||||
|
if (dim == -1) {
|
||||||
|
offsetByteSize = 4;
|
||||||
|
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
|
||||||
|
if (bytesRead < 4) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
|
||||||
|
);
|
||||||
|
}
|
||||||
|
dim = preamble.getInt(0);
|
||||||
|
if (dim <= 0) {
|
||||||
|
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FieldType fieldType = switch (vectorEncoding) {
|
||||||
|
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
|
||||||
|
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
|
||||||
|
};
|
||||||
|
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
|
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
|
||||||
);
|
);
|
||||||
|
@ -150,7 +165,7 @@ class KnnIndexer {
|
||||||
vectorEncoding.byteSize
|
vectorEncoding.byteSize
|
||||||
);
|
);
|
||||||
|
|
||||||
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding);
|
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
|
||||||
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
|
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
|
||||||
AtomicInteger numDocsIndexed = new AtomicInteger();
|
AtomicInteger numDocsIndexed = new AtomicInteger();
|
||||||
List<Future<?>> threads = new ArrayList<>();
|
List<Future<?>> threads = new ArrayList<>();
|
||||||
|
@ -271,21 +286,24 @@ class KnnIndexer {
|
||||||
|
|
||||||
static class VectorReader {
|
static class VectorReader {
|
||||||
final float[] target;
|
final float[] target;
|
||||||
|
final int offsetByteSize;
|
||||||
final ByteBuffer bytes;
|
final ByteBuffer bytes;
|
||||||
final FileChannel input;
|
final FileChannel input;
|
||||||
long position;
|
long position;
|
||||||
|
|
||||||
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) throws IOException {
|
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int offsetByteSize) throws IOException {
|
||||||
|
// check if dim is set as preamble in the file:
|
||||||
int bufferSize = dim * vectorEncoding.byteSize;
|
int bufferSize = dim * vectorEncoding.byteSize;
|
||||||
if (input.size() % ((long) dim * vectorEncoding.byteSize) != 0) {
|
if (input.size() % ((long) dim * vectorEncoding.byteSize + offsetByteSize) != 0) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size()
|
"vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return new VectorReader(input, dim, bufferSize);
|
return new VectorReader(input, dim, bufferSize, offsetByteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
|
VectorReader(FileChannel input, int dim, int bufferSize, int offsetByteSize) throws IOException {
|
||||||
|
this.offsetByteSize = offsetByteSize;
|
||||||
this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
|
this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
this.input = input;
|
this.input = input;
|
||||||
this.target = new float[dim];
|
this.target = new float[dim];
|
||||||
|
@ -293,14 +311,14 @@ class KnnIndexer {
|
||||||
}
|
}
|
||||||
|
|
||||||
void reset() throws IOException {
|
void reset() throws IOException {
|
||||||
position = 0;
|
position = offsetByteSize;
|
||||||
input.position(position);
|
input.position(position);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void readNext() throws IOException {
|
private void readNext() throws IOException {
|
||||||
int bytesRead = Channels.readFromFileChannel(this.input, position, bytes);
|
int bytesRead = Channels.readFromFileChannel(this.input, position, bytes);
|
||||||
if (bytesRead < bytes.capacity()) {
|
if (bytesRead < bytes.capacity()) {
|
||||||
position = 0;
|
position = offsetByteSize;
|
||||||
bytes.position(0);
|
bytes.position(0);
|
||||||
// wrap around back to the start of the file if we hit the end:
|
// wrap around back to the start of the file if we hit the end:
|
||||||
logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again");
|
logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again");
|
||||||
|
@ -312,7 +330,7 @@ class KnnIndexer {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
position += bytesRead;
|
position += bytesRead + offsetByteSize;
|
||||||
bytes.position(0);
|
bytes.position(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,7 @@ import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.store.FSDirectory;
|
import org.apache.lucene.store.FSDirectory;
|
||||||
import org.apache.lucene.store.MMapDirectory;
|
import org.apache.lucene.store.MMapDirectory;
|
||||||
|
import org.elasticsearch.common.io.Channels;
|
||||||
import org.elasticsearch.core.PathUtils;
|
import org.elasticsearch.core.PathUtils;
|
||||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
import org.elasticsearch.search.profile.query.QueryProfiler;
|
import org.elasticsearch.search.profile.query.QueryProfiler;
|
||||||
|
@ -87,7 +88,7 @@ class KnnSearcher {
|
||||||
private final int efSearch;
|
private final int efSearch;
|
||||||
private final int nProbe;
|
private final int nProbe;
|
||||||
private final KnnIndexTester.IndexType indexType;
|
private final KnnIndexTester.IndexType indexType;
|
||||||
private final int dim;
|
private int dim;
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final VectorEncoding vectorEncoding;
|
private final VectorEncoding vectorEncoding;
|
||||||
private final float overSamplingFactor;
|
private final float overSamplingFactor;
|
||||||
|
@ -117,6 +118,7 @@ class KnnSearcher {
|
||||||
TopDocs[] results = new TopDocs[numQueryVectors];
|
TopDocs[] results = new TopDocs[numQueryVectors];
|
||||||
int[][] resultIds = new int[numQueryVectors][];
|
int[][] resultIds = new int[numQueryVectors][];
|
||||||
long elapsed, totalCpuTimeMS, totalVisited = 0;
|
long elapsed, totalCpuTimeMS, totalVisited = 0;
|
||||||
|
int offsetByteSize = 0;
|
||||||
try (
|
try (
|
||||||
FileChannel input = FileChannel.open(queryPath);
|
FileChannel input = FileChannel.open(queryPath);
|
||||||
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
|
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
|
||||||
|
@ -128,7 +130,19 @@ class KnnSearcher {
|
||||||
+ " bytes, assuming vector count is "
|
+ " bytes, assuming vector count is "
|
||||||
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
|
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
|
||||||
);
|
);
|
||||||
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding);
|
if (dim == -1) {
|
||||||
|
offsetByteSize = 4;
|
||||||
|
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
int bytesRead = Channels.readFromFileChannel(input, 0, preamble);
|
||||||
|
if (bytesRead < 4) {
|
||||||
|
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" does not contain a valid dims?");
|
||||||
|
}
|
||||||
|
dim = preamble.getInt(0);
|
||||||
|
if (dim <= 0) {
|
||||||
|
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
|
||||||
long startNS;
|
long startNS;
|
||||||
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
|
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
|
||||||
try (DirectoryReader reader = DirectoryReader.open(dir)) {
|
try (DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||||
|
@ -191,7 +205,7 @@ class KnnSearcher {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.info("checking results");
|
logger.info("checking results");
|
||||||
int[][] nn = getOrCalculateExactNN();
|
int[][] nn = getOrCalculateExactNN(offsetByteSize);
|
||||||
finalResults.avgRecall = checkResults(resultIds, nn, topK);
|
finalResults.avgRecall = checkResults(resultIds, nn, topK);
|
||||||
finalResults.qps = (1000f * numQueryVectors) / elapsed;
|
finalResults.qps = (1000f * numQueryVectors) / elapsed;
|
||||||
finalResults.avgLatency = (float) elapsed / numQueryVectors;
|
finalResults.avgLatency = (float) elapsed / numQueryVectors;
|
||||||
|
@ -200,7 +214,7 @@ class KnnSearcher {
|
||||||
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
|
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[][] getOrCalculateExactNN() throws IOException {
|
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
|
||||||
// look in working directory for cached nn file
|
// look in working directory for cached nn file
|
||||||
String hash = Integer.toString(
|
String hash = Integer.toString(
|
||||||
Objects.hash(
|
Objects.hash(
|
||||||
|
@ -228,9 +242,9 @@ class KnnSearcher {
|
||||||
// checking low-precision recall
|
// checking low-precision recall
|
||||||
int[][] nn;
|
int[][] nn;
|
||||||
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
|
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
|
||||||
nn = computeExactNNByte(queryPath);
|
nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
|
||||||
} else {
|
} else {
|
||||||
nn = computeExactNN(queryPath);
|
nn = computeExactNN(queryPath, vectorFileOffsetBytes);
|
||||||
}
|
}
|
||||||
writeExactNN(nn, nnPath);
|
writeExactNN(nn, nnPath);
|
||||||
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
|
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
|
||||||
|
@ -356,12 +370,17 @@ class KnnSearcher {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[][] computeExactNN(Path queryPath) throws IOException {
|
private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
|
||||||
int[][] result = new int[numQueryVectors][];
|
int[][] result = new int[numQueryVectors][];
|
||||||
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
try (FileChannel qIn = FileChannel.open(queryPath)) {
|
try (FileChannel qIn = FileChannel.open(queryPath)) {
|
||||||
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
|
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(
|
||||||
|
qIn,
|
||||||
|
dim,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
|
vectorFileOffsetBytes
|
||||||
|
);
|
||||||
for (int i = 0; i < numQueryVectors; i++) {
|
for (int i = 0; i < numQueryVectors; i++) {
|
||||||
float[] queryVector = new float[dim];
|
float[] queryVector = new float[dim];
|
||||||
queryReader.next(queryVector);
|
queryReader.next(queryVector);
|
||||||
|
@ -373,12 +392,12 @@ class KnnSearcher {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[][] computeExactNNByte(Path queryPath) throws IOException {
|
private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
|
||||||
int[][] result = new int[numQueryVectors][];
|
int[][] result = new int[numQueryVectors][];
|
||||||
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||||
List<Callable<Void>> tasks = new ArrayList<>();
|
List<Callable<Void>> tasks = new ArrayList<>();
|
||||||
try (FileChannel qIn = FileChannel.open(queryPath)) {
|
try (FileChannel qIn = FileChannel.open(queryPath)) {
|
||||||
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE);
|
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE, vectorFileOffsetBytes);
|
||||||
for (int i = 0; i < numQueryVectors; i++) {
|
for (int i = 0; i < numQueryVectors; i++) {
|
||||||
byte[] queryVector = new byte[dim];
|
byte[] queryVector = new byte[dim];
|
||||||
queryReader.next(queryVector);
|
queryReader.next(queryVector);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue