mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-27 17:10:22 -04:00
* Reapply "Adds unused lower level ivf knn query (#127852)" (#128003)
This reverts commit 648d74bf97
.
* Fixing tests
This commit is contained in:
parent
3b0ef09e5b
commit
85f466207e
7 changed files with 1467 additions and 11 deletions
|
@ -447,18 +447,9 @@ tests:
|
|||
- class: org.elasticsearch.indices.stats.IndexStatsIT
|
||||
method: testThrottleStats
|
||||
issue: https://github.com/elastic/elasticsearch/issues/126359
|
||||
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
|
||||
method: testRandomWithFilter
|
||||
issue: https://github.com/elastic/elasticsearch/issues/127963
|
||||
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
|
||||
method: testSearchBoost
|
||||
issue: https://github.com/elastic/elasticsearch/issues/127969
|
||||
- class: org.elasticsearch.packaging.test.DockerTests
|
||||
method: test040JavaUsesTheOsProvidedKeystore
|
||||
issue: https://github.com/elastic/elasticsearch/issues/127437
|
||||
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
|
||||
method: testFindFewer
|
||||
issue: https://github.com/elastic/elasticsearch/issues/128002
|
||||
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityRestIT
|
||||
method: testTaskCancellation
|
||||
issue: https://github.com/elastic/elasticsearch/issues/128009
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.elasticsearch.core.IOUtils;
|
||||
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.IntPredicate;
|
||||
|
@ -243,8 +244,11 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
|
|||
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
|
||||
return;
|
||||
}
|
||||
// TODO add new ivf search strategy
|
||||
int nProbe = 10;
|
||||
if (fieldInfo.getVectorDimension() != target.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
|
||||
);
|
||||
}
|
||||
float percentFiltered = 1f;
|
||||
if (acceptDocs instanceof BitSet bitSet) {
|
||||
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
|
||||
|
@ -257,6 +261,13 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
return visitedDocs.getAndSet(docId) == false;
|
||||
};
|
||||
final int nProbe;
|
||||
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
|
||||
nProbe = ivfSearchStrategy.getNProbe();
|
||||
} else {
|
||||
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
|
||||
nProbe = 10;
|
||||
}
|
||||
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.search.vectors;
|
||||
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.FieldExistsQuery;
|
||||
import org.apache.lucene.search.FilteredDocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.MatchNoDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.QueryVisitor;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.TaskExecutor;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TopDocsCollector;
|
||||
import org.apache.lucene.search.TopKnnCollector;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.search.knn.KnnSearchStrategy;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.elasticsearch.search.profile.query.QueryProfiler;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
|
||||
|
||||
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
|
||||
|
||||
protected final String field;
|
||||
protected final int nProbe;
|
||||
protected final int k;
|
||||
protected final Query filter;
|
||||
protected final KnnSearchStrategy searchStrategy;
|
||||
protected int vectorOpsCount;
|
||||
|
||||
protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, Query filter) {
|
||||
this.field = field;
|
||||
this.nProbe = nProbe;
|
||||
this.k = k;
|
||||
this.filter = filter;
|
||||
this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {
|
||||
if (visitor.acceptField(field)) {
|
||||
visitor.visitLeaf(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
|
||||
return k == that.k
|
||||
&& Objects.equals(field, that.field)
|
||||
&& Objects.equals(filter, that.filter)
|
||||
&& Objects.equals(nProbe, that.nProbe);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, k, filter, nProbe);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
|
||||
vectorOpsCount = 0;
|
||||
IndexReader reader = indexSearcher.getIndexReader();
|
||||
|
||||
final Weight filterWeight;
|
||||
if (filter != null) {
|
||||
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(filter, BooleanClause.Occur.FILTER)
|
||||
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
|
||||
.build();
|
||||
Query rewritten = indexSearcher.rewrite(booleanQuery);
|
||||
if (rewritten.getClass() == MatchNoDocsQuery.class) {
|
||||
return rewritten;
|
||||
}
|
||||
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||
} else {
|
||||
filterWeight = null;
|
||||
}
|
||||
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
|
||||
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
|
||||
List<LeafReaderContext> leafReaderContexts = reader.leaves();
|
||||
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
|
||||
for (LeafReaderContext context : leafReaderContexts) {
|
||||
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
|
||||
}
|
||||
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
|
||||
|
||||
// Merge sort the results
|
||||
TopDocs topK = TopDocs.merge(k, perLeafResults);
|
||||
vectorOpsCount = (int) topK.totalHits.value();
|
||||
if (topK.scoreDocs.length == 0) {
|
||||
return new MatchNoDocsQuery();
|
||||
}
|
||||
return new KnnScoreDocQuery(topK.scoreDocs, reader);
|
||||
}
|
||||
|
||||
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
|
||||
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
|
||||
if (ctx.docBase > 0) {
|
||||
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||
scoreDoc.doc += ctx.docBase;
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
|
||||
final LeafReader reader = ctx.reader();
|
||||
final Bits liveDocs = reader.getLiveDocs();
|
||||
|
||||
if (filterWeight == null) {
|
||||
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
|
||||
}
|
||||
|
||||
Scorer scorer = filterWeight.scorer(ctx);
|
||||
if (scorer == null) {
|
||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||
}
|
||||
|
||||
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
|
||||
final int cost = acceptDocs.cardinality();
|
||||
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
|
||||
}
|
||||
|
||||
abstract TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager
|
||||
) throws IOException;
|
||||
|
||||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||
return new IVFCollectorManager(k, nProbe);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void profile(QueryProfiler queryProfiler) {
|
||||
queryProfiler.addVectorOpsCount(vectorOpsCount);
|
||||
}
|
||||
|
||||
BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
|
||||
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
|
||||
// If we already have a BitSet and no deletions, reuse the BitSet
|
||||
return bitSetIterator.getBitSet();
|
||||
} else {
|
||||
// Create a new BitSet from matching and live docs
|
||||
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
|
||||
@Override
|
||||
protected boolean match(int doc) {
|
||||
return liveDocs == null || liveDocs.get(doc);
|
||||
}
|
||||
};
|
||||
return BitSet.of(filterIterator, maxDoc);
|
||||
}
|
||||
}
|
||||
|
||||
static class IVFCollectorManager implements KnnCollectorManager {
|
||||
private final int k;
|
||||
private final int nprobe;
|
||||
|
||||
IVFCollectorManager(int k, int nprobe) {
|
||||
this.k = k;
|
||||
this.nprobe = nprobe;
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
|
||||
return new TopKnnCollector(k, visitedLimit, new IVFKnnSearchStrategy(nprobe));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.search.vectors;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
/** A {@link IVFKnnFloatVectorQuery} that uses the IVF search strategy. */
|
||||
public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {
|
||||
|
||||
private final float[] query;
|
||||
|
||||
/**
|
||||
* Creates a new {@link IVFKnnFloatVectorQuery} with the given parameters.
|
||||
* @param field the field to search
|
||||
* @param query the query vector
|
||||
* @param k the number of nearest neighbors to return
|
||||
* @param filter the filter to apply to the results
|
||||
* @param nProbe the number of probes to use for the IVF search strategy
|
||||
*/
|
||||
public IVFKnnFloatVectorQuery(String field, float[] query, int k, Query filter, int nProbe) {
|
||||
super(field, nProbe, k, filter);
|
||||
if (k < 1) {
|
||||
throw new IllegalArgumentException("k must be at least 1, got: " + k);
|
||||
}
|
||||
if (nProbe < 1) {
|
||||
throw new IllegalArgumentException("nProbe must be at least 1, got: " + nProbe);
|
||||
}
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
StringBuilder buffer = new StringBuilder();
|
||||
buffer.append(getClass().getSimpleName())
|
||||
.append(":")
|
||||
.append(this.field)
|
||||
.append("[")
|
||||
.append(query[0])
|
||||
.append(",...]")
|
||||
.append("[")
|
||||
.append(k)
|
||||
.append("]");
|
||||
if (this.filter != null) {
|
||||
buffer.append("[").append(this.filter).append("]");
|
||||
}
|
||||
return buffer.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (super.equals(o) == false) return false;
|
||||
IVFKnnFloatVectorQuery that = (IVFKnnFloatVectorQuery) o;
|
||||
return Arrays.equals(query, that.query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = super.hashCode();
|
||||
result = 31 * result + Arrays.hashCode(query);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TopDocs approximateSearch(
|
||||
LeafReaderContext context,
|
||||
Bits acceptDocs,
|
||||
int visitedLimit,
|
||||
KnnCollectorManager knnCollectorManager
|
||||
) throws IOException {
|
||||
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
|
||||
LeafReader reader = context.reader();
|
||||
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
|
||||
if (floatVectorValues == null) {
|
||||
FloatVectorValues.checkField(reader, field);
|
||||
return NO_RESULTS;
|
||||
}
|
||||
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
|
||||
return NO_RESULTS;
|
||||
}
|
||||
reader.searchNearestVectors(field, query, knnCollector, acceptDocs);
|
||||
TopDocs results = knnCollector.topDocs();
|
||||
return results != null ? results : NO_RESULTS;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.search.vectors;
|
||||
|
||||
import org.apache.lucene.search.knn.KnnSearchStrategy;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class IVFKnnSearchStrategy extends KnnSearchStrategy {
|
||||
private final int nProbe;
|
||||
|
||||
IVFKnnSearchStrategy(int nProbe) {
|
||||
this.nProbe = nProbe;
|
||||
}
|
||||
|
||||
public int getNProbe() {
|
||||
return nProbe;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
IVFKnnSearchStrategy that = (IVFKnnSearchStrategy) o;
|
||||
return nProbe == that.nProbe;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(nProbe);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void nextVectorsBlock() {
|
||||
// do nothing
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.search.vectors;
|
||||
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat;
|
||||
|
||||
public class IVFKnnFloatVectorQueryTests extends AbstractIVFKnnVectorQueryTestCase {
|
||||
|
||||
@Override
|
||||
IVFKnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe) {
|
||||
return new IVFKnnFloatVectorQuery(field, query, k, queryFilter, nProbe);
|
||||
}
|
||||
|
||||
@Override
|
||||
float[] randomVector(int dim) {
|
||||
float[] vector = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
vector[i] = randomFloat();
|
||||
}
|
||||
VectorUtil.l2normalize(vector);
|
||||
return vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
Field getKnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||
return new KnnFloatVectorField(name, vector, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
Field getKnnVectorField(String name, float[] vector) {
|
||||
return new KnnFloatVectorField(name, vector);
|
||||
}
|
||||
|
||||
public void testToString() throws IOException {
|
||||
try (
|
||||
Directory indexStore = getIndexStore("field", new float[] { 0, 1 }, new float[] { 1, 2 }, new float[] { 0, 0 });
|
||||
IndexReader reader = DirectoryReader.open(indexStore)
|
||||
) {
|
||||
AbstractIVFKnnVectorQuery query = getKnnVectorQuery("field", new float[] { 0.0f, 1.0f }, 10);
|
||||
assertEquals("IVFKnnFloatVectorQuery:field[0.0,...][10]", query.toString("ignored"));
|
||||
|
||||
assertDocScoreQueryToString(query.rewrite(newSearcher(reader)));
|
||||
|
||||
// test with filter
|
||||
Query filter = new TermQuery(new Term("id", "text"));
|
||||
query = getKnnVectorQuery("field", new float[] { 0.0f, 1.0f }, 10, filter);
|
||||
assertEquals("IVFKnnFloatVectorQuery:field[0.0,...][10][id:text]", query.toString("ignored"));
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue