mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
* Reapply "Adds unused lower level ivf knn query (#127852)" (#128003)
This reverts commit 648d74bf97.
* Fixing tests
This commit is contained in:
parent
3b0ef09e5b
commit
85f466207e
7 changed files with 1467 additions and 11 deletions
|
@ -447,18 +447,9 @@ tests:
|
||||||
- class: org.elasticsearch.indices.stats.IndexStatsIT
|
- class: org.elasticsearch.indices.stats.IndexStatsIT
|
||||||
method: testThrottleStats
|
method: testThrottleStats
|
||||||
issue: https://github.com/elastic/elasticsearch/issues/126359
|
issue: https://github.com/elastic/elasticsearch/issues/126359
|
||||||
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
|
|
||||||
method: testRandomWithFilter
|
|
||||||
issue: https://github.com/elastic/elasticsearch/issues/127963
|
|
||||||
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
|
|
||||||
method: testSearchBoost
|
|
||||||
issue: https://github.com/elastic/elasticsearch/issues/127969
|
|
||||||
- class: org.elasticsearch.packaging.test.DockerTests
|
- class: org.elasticsearch.packaging.test.DockerTests
|
||||||
method: test040JavaUsesTheOsProvidedKeystore
|
method: test040JavaUsesTheOsProvidedKeystore
|
||||||
issue: https://github.com/elastic/elasticsearch/issues/127437
|
issue: https://github.com/elastic/elasticsearch/issues/127437
|
||||||
- class: org.elasticsearch.search.vectors.IVFKnnFloatVectorQueryTests
|
|
||||||
method: testFindFewer
|
|
||||||
issue: https://github.com/elastic/elasticsearch/issues/128002
|
|
||||||
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityRestIT
|
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityRestIT
|
||||||
method: testTaskCancellation
|
method: testTaskCancellation
|
||||||
issue: https://github.com/elastic/elasticsearch/issues/128009
|
issue: https://github.com/elastic/elasticsearch/issues/128009
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
import org.elasticsearch.core.IOUtils;
|
import org.elasticsearch.core.IOUtils;
|
||||||
|
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.function.IntPredicate;
|
import java.util.function.IntPredicate;
|
||||||
|
@ -243,8 +244,11 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
|
||||||
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
|
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// TODO add new ivf search strategy
|
if (fieldInfo.getVectorDimension() != target.length) {
|
||||||
int nProbe = 10;
|
throw new IllegalArgumentException(
|
||||||
|
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
|
||||||
|
);
|
||||||
|
}
|
||||||
float percentFiltered = 1f;
|
float percentFiltered = 1f;
|
||||||
if (acceptDocs instanceof BitSet bitSet) {
|
if (acceptDocs instanceof BitSet bitSet) {
|
||||||
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
|
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
|
||||||
|
@ -257,6 +261,13 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
return visitedDocs.getAndSet(docId) == false;
|
return visitedDocs.getAndSet(docId) == false;
|
||||||
};
|
};
|
||||||
|
final int nProbe;
|
||||||
|
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
|
||||||
|
nProbe = ivfSearchStrategy.getNProbe();
|
||||||
|
} else {
|
||||||
|
// TODO calculate nProbe given the number of centroids vs. number of vectors for given `k`
|
||||||
|
nProbe = 10;
|
||||||
|
}
|
||||||
|
|
||||||
FieldEntry entry = fields.get(fieldInfo.number);
|
FieldEntry entry = fields.get(fieldInfo.number);
|
||||||
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
|
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.search.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.IndexReader;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.search.BooleanClause;
|
||||||
|
import org.apache.lucene.search.BooleanQuery;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.search.FieldExistsQuery;
|
||||||
|
import org.apache.lucene.search.FilteredDocIdSetIterator;
|
||||||
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
|
import org.apache.lucene.search.KnnCollector;
|
||||||
|
import org.apache.lucene.search.MatchNoDocsQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.search.QueryVisitor;
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.lucene.search.ScoreMode;
|
||||||
|
import org.apache.lucene.search.Scorer;
|
||||||
|
import org.apache.lucene.search.TaskExecutor;
|
||||||
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
import org.apache.lucene.search.TopDocsCollector;
|
||||||
|
import org.apache.lucene.search.TopKnnCollector;
|
||||||
|
import org.apache.lucene.search.Weight;
|
||||||
|
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||||
|
import org.apache.lucene.search.knn.KnnSearchStrategy;
|
||||||
|
import org.apache.lucene.util.BitSet;
|
||||||
|
import org.apache.lucene.util.BitSetIterator;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.elasticsearch.search.profile.query.QueryProfiler;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
|
||||||
|
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {
|
||||||
|
|
||||||
|
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
|
||||||
|
protected final String field;
|
||||||
|
protected final int nProbe;
|
||||||
|
protected final int k;
|
||||||
|
protected final Query filter;
|
||||||
|
protected final KnnSearchStrategy searchStrategy;
|
||||||
|
protected int vectorOpsCount;
|
||||||
|
|
||||||
|
protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, Query filter) {
|
||||||
|
this.field = field;
|
||||||
|
this.nProbe = nProbe;
|
||||||
|
this.k = k;
|
||||||
|
this.filter = filter;
|
||||||
|
this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visit(QueryVisitor visitor) {
|
||||||
|
if (visitor.acceptField(field)) {
|
||||||
|
visitor.visitLeaf(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
|
||||||
|
return k == that.k
|
||||||
|
&& Objects.equals(field, that.field)
|
||||||
|
&& Objects.equals(filter, that.filter)
|
||||||
|
&& Objects.equals(nProbe, that.nProbe);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(field, k, filter, nProbe);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
|
||||||
|
vectorOpsCount = 0;
|
||||||
|
IndexReader reader = indexSearcher.getIndexReader();
|
||||||
|
|
||||||
|
final Weight filterWeight;
|
||||||
|
if (filter != null) {
|
||||||
|
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(filter, BooleanClause.Occur.FILTER)
|
||||||
|
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
|
||||||
|
.build();
|
||||||
|
Query rewritten = indexSearcher.rewrite(booleanQuery);
|
||||||
|
if (rewritten.getClass() == MatchNoDocsQuery.class) {
|
||||||
|
return rewritten;
|
||||||
|
}
|
||||||
|
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||||
|
} else {
|
||||||
|
filterWeight = null;
|
||||||
|
}
|
||||||
|
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
|
||||||
|
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
|
||||||
|
List<LeafReaderContext> leafReaderContexts = reader.leaves();
|
||||||
|
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
|
||||||
|
for (LeafReaderContext context : leafReaderContexts) {
|
||||||
|
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
|
||||||
|
}
|
||||||
|
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
|
||||||
|
|
||||||
|
// Merge sort the results
|
||||||
|
TopDocs topK = TopDocs.merge(k, perLeafResults);
|
||||||
|
vectorOpsCount = (int) topK.totalHits.value();
|
||||||
|
if (topK.scoreDocs.length == 0) {
|
||||||
|
return new MatchNoDocsQuery();
|
||||||
|
}
|
||||||
|
return new KnnScoreDocQuery(topK.scoreDocs, reader);
|
||||||
|
}
|
||||||
|
|
||||||
|
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
|
||||||
|
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
|
||||||
|
if (ctx.docBase > 0) {
|
||||||
|
for (ScoreDoc scoreDoc : results.scoreDocs) {
|
||||||
|
scoreDoc.doc += ctx.docBase;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
|
||||||
|
final LeafReader reader = ctx.reader();
|
||||||
|
final Bits liveDocs = reader.getLiveDocs();
|
||||||
|
|
||||||
|
if (filterWeight == null) {
|
||||||
|
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
Scorer scorer = filterWeight.scorer(ctx);
|
||||||
|
if (scorer == null) {
|
||||||
|
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
|
||||||
|
final int cost = acceptDocs.cardinality();
|
||||||
|
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract TopDocs approximateSearch(
|
||||||
|
LeafReaderContext context,
|
||||||
|
Bits acceptDocs,
|
||||||
|
int visitedLimit,
|
||||||
|
KnnCollectorManager knnCollectorManager
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
|
||||||
|
return new IVFCollectorManager(k, nProbe);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void profile(QueryProfiler queryProfiler) {
|
||||||
|
queryProfiler.addVectorOpsCount(vectorOpsCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
|
||||||
|
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
|
||||||
|
// If we already have a BitSet and no deletions, reuse the BitSet
|
||||||
|
return bitSetIterator.getBitSet();
|
||||||
|
} else {
|
||||||
|
// Create a new BitSet from matching and live docs
|
||||||
|
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
|
||||||
|
@Override
|
||||||
|
protected boolean match(int doc) {
|
||||||
|
return liveDocs == null || liveDocs.get(doc);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return BitSet.of(filterIterator, maxDoc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static class IVFCollectorManager implements KnnCollectorManager {
|
||||||
|
private final int k;
|
||||||
|
private final int nprobe;
|
||||||
|
|
||||||
|
IVFCollectorManager(int k, int nprobe) {
|
||||||
|
this.k = k;
|
||||||
|
this.nprobe = nprobe;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
|
||||||
|
return new TopKnnCollector(k, visitedLimit, new IVFKnnSearchStrategy(nprobe));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,101 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.search.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.search.KnnCollector;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
import org.apache.lucene.search.knn.KnnCollectorManager;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/** A {@link IVFKnnFloatVectorQuery} that uses the IVF search strategy. */
|
||||||
|
public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery {
|
||||||
|
|
||||||
|
private final float[] query;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new {@link IVFKnnFloatVectorQuery} with the given parameters.
|
||||||
|
* @param field the field to search
|
||||||
|
* @param query the query vector
|
||||||
|
* @param k the number of nearest neighbors to return
|
||||||
|
* @param filter the filter to apply to the results
|
||||||
|
* @param nProbe the number of probes to use for the IVF search strategy
|
||||||
|
*/
|
||||||
|
public IVFKnnFloatVectorQuery(String field, float[] query, int k, Query filter, int nProbe) {
|
||||||
|
super(field, nProbe, k, filter);
|
||||||
|
if (k < 1) {
|
||||||
|
throw new IllegalArgumentException("k must be at least 1, got: " + k);
|
||||||
|
}
|
||||||
|
if (nProbe < 1) {
|
||||||
|
throw new IllegalArgumentException("nProbe must be at least 1, got: " + nProbe);
|
||||||
|
}
|
||||||
|
this.query = query;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(String field) {
|
||||||
|
StringBuilder buffer = new StringBuilder();
|
||||||
|
buffer.append(getClass().getSimpleName())
|
||||||
|
.append(":")
|
||||||
|
.append(this.field)
|
||||||
|
.append("[")
|
||||||
|
.append(query[0])
|
||||||
|
.append(",...]")
|
||||||
|
.append("[")
|
||||||
|
.append(k)
|
||||||
|
.append("]");
|
||||||
|
if (this.filter != null) {
|
||||||
|
buffer.append("[").append(this.filter).append("]");
|
||||||
|
}
|
||||||
|
return buffer.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (super.equals(o) == false) return false;
|
||||||
|
IVFKnnFloatVectorQuery that = (IVFKnnFloatVectorQuery) o;
|
||||||
|
return Arrays.equals(query, that.query);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
int result = super.hashCode();
|
||||||
|
result = 31 * result + Arrays.hashCode(query);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TopDocs approximateSearch(
|
||||||
|
LeafReaderContext context,
|
||||||
|
Bits acceptDocs,
|
||||||
|
int visitedLimit,
|
||||||
|
KnnCollectorManager knnCollectorManager
|
||||||
|
) throws IOException {
|
||||||
|
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context);
|
||||||
|
LeafReader reader = context.reader();
|
||||||
|
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
|
||||||
|
if (floatVectorValues == null) {
|
||||||
|
FloatVectorValues.checkField(reader, field);
|
||||||
|
return NO_RESULTS;
|
||||||
|
}
|
||||||
|
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
|
||||||
|
return NO_RESULTS;
|
||||||
|
}
|
||||||
|
reader.searchNearestVectors(field, query, knnCollector, acceptDocs);
|
||||||
|
TopDocs results = knnCollector.topDocs();
|
||||||
|
return results != null ? results : NO_RESULTS;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.search.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.knn.KnnSearchStrategy;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class IVFKnnSearchStrategy extends KnnSearchStrategy {
|
||||||
|
private final int nProbe;
|
||||||
|
|
||||||
|
IVFKnnSearchStrategy(int nProbe) {
|
||||||
|
this.nProbe = nProbe;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNProbe() {
|
||||||
|
return nProbe;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
IVFKnnSearchStrategy that = (IVFKnnSearchStrategy) o;
|
||||||
|
return nProbe == that.nProbe;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hashCode(nProbe);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void nextVectorsBlock() {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,69 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.search.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.document.Field;
|
||||||
|
import org.apache.lucene.document.KnnFloatVectorField;
|
||||||
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
|
import org.apache.lucene.index.IndexReader;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.search.TermQuery;
|
||||||
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat;
|
||||||
|
|
||||||
|
public class IVFKnnFloatVectorQueryTests extends AbstractIVFKnnVectorQueryTestCase {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
IVFKnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe) {
|
||||||
|
return new IVFKnnFloatVectorQuery(field, query, k, queryFilter, nProbe);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
float[] randomVector(int dim) {
|
||||||
|
float[] vector = new float[dim];
|
||||||
|
for (int i = 0; i < dim; i++) {
|
||||||
|
vector[i] = randomFloat();
|
||||||
|
}
|
||||||
|
VectorUtil.l2normalize(vector);
|
||||||
|
return vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Field getKnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||||
|
return new KnnFloatVectorField(name, vector, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Field getKnnVectorField(String name, float[] vector) {
|
||||||
|
return new KnnFloatVectorField(name, vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToString() throws IOException {
|
||||||
|
try (
|
||||||
|
Directory indexStore = getIndexStore("field", new float[] { 0, 1 }, new float[] { 1, 2 }, new float[] { 0, 0 });
|
||||||
|
IndexReader reader = DirectoryReader.open(indexStore)
|
||||||
|
) {
|
||||||
|
AbstractIVFKnnVectorQuery query = getKnnVectorQuery("field", new float[] { 0.0f, 1.0f }, 10);
|
||||||
|
assertEquals("IVFKnnFloatVectorQuery:field[0.0,...][10]", query.toString("ignored"));
|
||||||
|
|
||||||
|
assertDocScoreQueryToString(query.rewrite(newSearcher(reader)));
|
||||||
|
|
||||||
|
// test with filter
|
||||||
|
Query filter = new TermQuery(new Term("id", "text"));
|
||||||
|
query = getKnnVectorQuery("field", new float[] { 0.0f, 1.0f }, 10, filter);
|
||||||
|
assertEquals("IVFKnnFloatVectorQuery:field[0.0,...][10][id:text]", query.toString("ignored"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue