Reapply "Adds unused lower level ivf knn query (#127852)" (#128003) (#128052)

* Reapply "Adds unused lower level ivf knn query (#127852)" (#128003)

This reverts commit 648d74bf97.

* Fixing tests
This commit is contained in:
Benjamin Trent 2025-05-14 11:18:25 -04:00 committed by GitHub
parent 3b0ef09e5b
commit 85f466207e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 1467 additions and 11 deletions

View file

@ -447,18 +447,9 @@ tests:
- class: org.elasticsearch.indices.stats.IndexStatsIT
method: testThrottleStats
issue: https://github.com/elastic/elasticsearch/issues/126359
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
method: testRandomWithFilter
issue: https://github.com/elastic/elasticsearch/issues/127963
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
method: testSearchBoost
issue: https://github.com/elastic/elasticsearch/issues/127969
- class: org.elasticsearch.packaging.test.DockerTests
method: test040JavaUsesTheOsProvidedKeystore
issue: https://github.com/elastic/elasticsearch/issues/127437
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
method: testFindFewer
issue: https://github.com/elastic/elasticsearch/issues/128002
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityRestIT
method: testTaskCancellation
issue: https://github.com/elastic/elasticsearch/issues/128009

View file

@ -32,6 +32,7 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
import java.io.IOException;
import java.util.function.IntPredicate;
@ -243,8 +244,11 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
return;
}
// TODO add new ivf search strategy
int nProbe = 10;
if (fieldInfo.getVectorDimension() != target.length) {
throw new IllegalArgumentException(
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
);
}
float percentFiltered = 1f;
if (acceptDocs instanceof BitSet bitSet) {
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
@ -257,6 +261,13 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
}
return visitedDocs.getAndSet(docId) == false;
};
final int nProbe;
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
nProbe = ivfSearchStrategy.getNProbe();
} else {
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
nProbe = 10;
}
FieldEntry entry = fields.get(fieldInfo.number);
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(

View file

@ -0,0 +1,198 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.vectors;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.elasticsearch.search.profile.query.QueryProfiler;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
/**
 * Base class for kNN vector queries that are executed with an {@link IVFKnnSearchStrategy}.
 * <p>
 * {@link #rewrite(IndexSearcher)} runs the search eagerly: every leaf of the index reader is searched
 * (optionally restricted by {@link #filter}), the per-leaf hits are merged into the overall top
 * {@code k}, and the result is returned as a {@link KnnScoreDocQuery}. Subclasses provide the actual
 * per-leaf vector search through {@link #approximateSearch}.
 */
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {

    /** Shared empty result used whenever a leaf has nothing to contribute. */
    static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

    protected final String field;
    protected final int nProbe;
    protected final int k;
    protected final Query filter;
    protected final KnnSearchStrategy searchStrategy;
    // Written by rewrite() from the merged total hit count and reported through profile().
    protected int vectorOpsCount;

    /**
     * @param field  the vector field to search
     * @param nProbe the nProbe value handed to the {@link IVFKnnSearchStrategy}
     * @param k      the number of hits to keep after merging all leaves
     * @param filter optional filter query; {@code null} means no filtering
     */
    protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, Query filter) {
        this.field = field;
        this.nProbe = nProbe;
        this.k = k;
        this.filter = filter;
        this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(field)) {
            visitor.visitLeaf(this);
        }
    }

    /**
     * Two queries are equal when they are of the same class and agree on {@code field}, {@code k},
     * {@code filter} and {@code nProbe}. {@code searchStrategy} is derived from {@code nProbe} and
     * {@code vectorOpsCount} is transient search state, so neither takes part.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
        // nProbe is a primitive: compare it directly instead of boxing it through Objects.equals.
        return k == that.k && nProbe == that.nProbe && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, k, filter, nProbe);
    }

    /**
     * Executes the kNN search against every leaf of the searcher's reader and rewrites this query
     * into one that matches exactly the merged top {@code k} hits.
     *
     * @return a {@link MatchNoDocsQuery} when the filter rewrites to one or when no hits were found,
     *         otherwise a {@link KnnScoreDocQuery} over the merged hits
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        vectorOpsCount = 0;
        IndexReader reader = indexSearcher.getIndexReader();

        final Weight filterWeight;
        if (filter != null) {
            // Candidates must match the filter and also have a value for the searched field.
            BooleanQuery booleanQuery = new BooleanQuery.Builder().add(filter, BooleanClause.Occur.FILTER)
                .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
                .build();
            Query rewritten = indexSearcher.rewrite(booleanQuery);
            if (rewritten.getClass() == MatchNoDocsQuery.class) {
                return rewritten;
            }
            filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
        } else {
            filterWeight = null;
        }

        KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        List<LeafReaderContext> leafReaderContexts = reader.leaves();
        // One task per leaf, all submitted through the searcher's task executor.
        List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext context : leafReaderContexts) {
            tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
        }
        TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);

        // Merge sort the results
        TopDocs topK = TopDocs.merge(k, perLeafResults);
        // The merged total is a long; clamp rather than cast so a very large sum cannot wrap around.
        vectorOpsCount = (int) Math.min(Integer.MAX_VALUE, topK.totalHits.value());
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return new KnnScoreDocQuery(topK.scoreDocs, reader);
    }

    /** Searches a single leaf and shifts the returned doc ids by the leaf's {@code docBase}. */
    private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
        TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
        if (ctx.docBase > 0) {
            for (ScoreDoc scoreDoc : results.scoreDocs) {
                scoreDoc.doc += ctx.docBase;
            }
        }
        return results;
    }

    /**
     * Computes the hits for one leaf.
     * <p>
     * Without a filter the search runs over the leaf's live docs with a visited limit of
     * {@link Integer#MAX_VALUE}. With a filter, the matching live documents are collected into a
     * {@link BitSet} that is used as the accept set, and the visited limit is its cardinality plus one.
     *
     * @param filterWeight weight of the rewritten filter, or {@code null} when there is no filter
     */
    TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
        final LeafReader reader = ctx.reader();
        final Bits liveDocs = reader.getLiveDocs();

        if (filterWeight == null) {
            return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
        }

        Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            // A null scorer means the filter has no matches in this leaf.
            return NO_RESULTS;
        }

        BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
        final int cost = acceptDocs.cardinality();
        return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
    }

    /**
     * Runs the vector search on one leaf.
     *
     * @param context             the leaf to search
     * @param acceptDocs          documents that may be returned; may be {@code null} (the leaf's live
     *                            docs are passed through as-is when no filter is set)
     * @param visitedLimit        limit passed on to the collector created for this leaf
     * @param knnCollectorManager factory for the per-leaf collector
     * @return the hits for this leaf, with leaf-relative doc ids
     */
    abstract TopDocs approximateSearch(
        LeafReaderContext context,
        Bits acceptDocs,
        int visitedLimit,
        KnnCollectorManager knnCollectorManager
    ) throws IOException;

    /** Creates the collector manager used by {@link #rewrite(IndexSearcher)}; the searcher is not used here. */
    protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return new IVFCollectorManager(k, nProbe);
    }

    /** Reports the vector operation count recorded by the last {@link #rewrite(IndexSearcher)} call. */
    @Override
    public final void profile(QueryProfiler queryProfiler) {
        queryProfiler.addVectorOpsCount(vectorOpsCount);
    }

    /**
     * Builds the accept set for a filtered search from the filter's iterator.
     *
     * @param iterator iterator over the documents matching the filter
     * @param liveDocs live docs of the leaf, or {@code null} when it has no deletions
     * @param maxDoc   size of the bit set to create
     */
    BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
        if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
            // If we already have a BitSet and no deletions, reuse the BitSet
            return bitSetIterator.getBitSet();
        } else {
            // Create a new BitSet from matching and live docs
            FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
                @Override
                protected boolean match(int doc) {
                    return liveDocs == null || liveDocs.get(doc);
                }
            };
            return BitSet.of(filterIterator, maxDoc);
        }
    }

    /** Collector manager that creates {@link TopKnnCollector}s configured with an {@link IVFKnnSearchStrategy}. */
    static class IVFCollectorManager implements KnnCollectorManager {
        private final int k;
        private final int nprobe;

        IVFCollectorManager(int k, int nprobe) {
            this.k = k;
            this.nprobe = nprobe;
        }

        /**
         * Creates a top-{@code k} collector. The {@code searchStrategy} and {@code context} arguments
         * are ignored; the collector always gets a fresh {@link IVFKnnSearchStrategy} built from this
         * manager's {@code nprobe}.
         */
        @Override
        public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
            return new TopKnnCollector(k, visitedLimit, new IVFKnnSearchStrategy(nprobe));
        }
    }
}

View file

@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.vectors;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;
import java.io.IOException;
import java.util.Arrays;
/** A kNN query over a float vector field that uses the IVF search strategy. */
public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {

    private final float[] query;

    /**
     * Creates a new {@link IVFKnnFloatVectorQuery} with the given parameters.
     * @param field the field to search
     * @param query the query vector; must not be {@code null}
     * @param k the number of nearest neighbors to return
     * @param filter the filter to apply to the results
     * @param nProbe the number of probes to use for the IVF search strategy
     * @throws IllegalArgumentException if {@code k} or {@code nProbe} is less than 1, or {@code query} is {@code null}
     */
    public IVFKnnFloatVectorQuery(String field, float[] query, int k, Query filter, int nProbe) {
        super(field, nProbe, k, filter);
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        if (nProbe < 1) {
            throw new IllegalArgumentException("nProbe must be at least 1, got: " + nProbe);
        }
        // Fail fast here instead of with a NullPointerException later in toString() or at search time.
        if (query == null) {
            throw new IllegalArgumentException("query vector must not be null");
        }
        this.query = query;
    }

    /**
     * Renders the query as {@code ClassName:field[first,...][k]}, followed by {@code [filter]} when a
     * filter is set. The {@code field} argument is not used; the query's own field is printed.
     */
    @Override
    public String toString(String field) {
        StringBuilder buffer = new StringBuilder();
        buffer.append(getClass().getSimpleName()).append(":").append(this.field).append("[");
        // Only the first component is printed; guard against an empty vector so toString never throws.
        if (query.length > 0) {
            buffer.append(query[0]).append(",...");
        }
        buffer.append("]").append("[").append(k).append("]");
        if (this.filter != null) {
            buffer.append("[").append(this.filter).append("]");
        }
        return buffer.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        // super.equals also verifies that o is non-null and of exactly this class, so the cast is safe.
        if (super.equals(o) == false) return false;
        IVFKnnFloatVectorQuery that = (IVFKnnFloatVectorQuery) o;
        return Arrays.equals(query, that.query);
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + Arrays.hashCode(query);
        return result;
    }

    /**
     * Searches one leaf for the nearest neighbors of the query vector.
     *
     * @return {@link #NO_RESULTS} when the leaf has no float vectors for the field, when the collector's
     *         {@code k} or the number of vectors is zero, or when the collector yields no top docs;
     *         otherwise the collector's top docs
     */
    @Override
    protected TopDocs approximateSearch(
        LeafReaderContext context,
        Bits acceptDocs,
        int visitedLimit,
        KnnCollectorManager knnCollectorManager
    ) throws IOException {
        KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
        LeafReader reader = context.reader();
        FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
        if (floatVectorValues == null) {
            FloatVectorValues.checkField(reader, field);
            return NO_RESULTS;
        }
        if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
            return NO_RESULTS;
        }
        reader.searchNearestVectors(field, query, knnCollector, acceptDocs);
        TopDocs results = knnCollector.topDocs();
        return results != null ? results : NO_RESULTS;
    }
}

View file

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.vectors;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import java.util.Objects;
/** A {@link KnnSearchStrategy} that carries the {@code nProbe} setting for an IVF search. */
public class IVFKnnSearchStrategy extends KnnSearchStrategy {

    private final int nProbe;

    IVFKnnSearchStrategy(int nProbe) {
        this.nProbe = nProbe;
    }

    /** Returns the {@code nProbe} value this strategy was created with. */
    public int getNProbe() {
        return nProbe;
    }

    /** Strategies are equal when they are of the same class and have the same {@code nProbe}. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        IVFKnnSearchStrategy other = (IVFKnnSearchStrategy) o;
        return other.nProbe == nProbe;
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(nProbe);
    }

    /** Intentionally a no-op. */
    @Override
    public void nextVectorsBlock() {
        // do nothing
    }
}

View file

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.vectors;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.VectorUtil;
import java.io.IOException;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat;
/** Runs the shared IVF kNN query test case against {@link IVFKnnFloatVectorQuery}. */
public class IVFKnnFloatVectorQueryTests extends AbstractIVFKnnVectorQueryTestCase {

    @Override
    IVFKnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe) {
        return new IVFKnnFloatVectorQuery(field, query, k, queryFilter, nProbe);
    }

    /** Returns an L2-normalized vector of {@code dim} random floats. */
    @Override
    float[] randomVector(int dim) {
        float[] values = new float[dim];
        for (int idx = 0; idx < values.length; idx++) {
            values[idx] = randomFloat();
        }
        // l2normalize normalizes in place and hands back the same array.
        return VectorUtil.l2normalize(values);
    }

    @Override
    Field getKnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
        return new KnnFloatVectorField(name, vector, similarityFunction);
    }

    @Override
    Field getKnnVectorField(String name, float[] vector) {
        return new KnnFloatVectorField(name, vector);
    }

    /** Checks the string form of the query, both without and with a filter. */
    public void testToString() throws IOException {
        try (
            Directory indexStore = getIndexStore("field", new float[] { 0, 1 }, new float[] { 1, 2 }, new float[] { 0, 0 });
            IndexReader reader = DirectoryReader.open(indexStore)
        ) {
            AbstractIVFKnnVectorQuery unfiltered = getKnnVectorQuery("field", new float[] { 0.0f, 1.0f }, 10);
            assertEquals("IVFKnnFloatVectorQuery:field[0.0,...][10]", unfiltered.toString("ignored"));
            assertDocScoreQueryToString(unfiltered.rewrite(newSearcher(reader)));

            // test with filter
            Query filter = new TermQuery(new Term("id", "text"));
            AbstractIVFKnnVectorQuery filtered = getKnnVectorQuery("field", new float[] { 0.0f, 1.0f }, 10, filter);
            assertEquals("IVFKnnFloatVectorQuery:field[0.0,...][10][id:text]", filtered.toString("ignored"));
        }
    }
}