Use collector managers in DfsPhase (#96689)

This change switches the part of the DfsPhase that runs a knn query
from a single-threaded collector to a potentially multi-threaded collector
manager with an optional profiling wrapper. This doesn't in itself enable
concurrent execution, but it lays the foundation for switching the executor
in the ContextIndexSearcher to actually enable it.
This commit is contained in:
Christoph Büscher 2023-07-03 13:46:17 +02:00 committed by GitHub
parent 7f9b9a5677
commit 5eeaecd9cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 190 additions and 57 deletions

View file

@ -0,0 +1,5 @@
pr: 96689
summary: Use a collector manager in DfsPhase Knn Search
area: Search
type: enhancement
issues: []

View file

@ -111,7 +111,6 @@ public class DfsProfilerIT extends ESIntegTestCase {
CollectorResult result = queryProfileShardResult.getCollectorResult(); CollectorResult result = queryProfileShardResult.getCollectorResult();
assertThat(result.getName(), is(not(emptyOrNullString()))); assertThat(result.getName(), is(not(emptyOrNullString())));
assertThat(result.getTime(), greaterThan(0L)); assertThat(result.getTime(), greaterThan(0L));
assertThat(result.getTime(), greaterThan(0L));
} }
ProfileResult statsResult = searchProfileDfsPhaseResult.getDfsShardResult(); ProfileResult statsResult = searchProfileDfsPhaseResult.getDfsShardResult();
assertThat(statsResult.getQueryName(), equalTo("statistics")); assertThat(statsResult.getQueryName(), equalTo("statistics"));

View file

@ -16,19 +16,20 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TopScoreDocCollector;
import org.elasticsearch.index.query.ParsedQuery; import org.elasticsearch.index.query.ParsedQuery;
import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.Timer; import org.elasticsearch.search.profile.Timer;
import org.elasticsearch.search.profile.dfs.DfsProfiler; import org.elasticsearch.search.profile.dfs.DfsProfiler;
import org.elasticsearch.search.profile.dfs.DfsTimingType; import org.elasticsearch.search.profile.dfs.DfsTimingType;
import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector; import org.elasticsearch.search.profile.query.ProfileCollectorManager;
import org.elasticsearch.search.profile.query.InternalProfileCollectorManager;
import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.query.SingleThreadCollectorManager;
import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
@ -43,7 +44,6 @@ import java.util.Map;
/** /**
* DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase. * DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
* The additional information is used to better compare the scores coming from all the shards, which depend on local factors (e.g. idf). * The additional information is used to better compare the scores coming from all the shards, which depend on local factors (e.g. idf).
*
* When a kNN search is provided alongside the query, the DFS phase is also used to gather the top k candidates from each shard. Then the * When a kNN search is provided alongside the query, the DFS phase is also used to gather the top k candidates from each shard. Then the
* global top k hits are passed on to the query phase. * global top k hits are passed on to the query phase.
*/ */
@ -189,24 +189,27 @@ public class DfsPhase {
List<DfsKnnResults> knnResults = new ArrayList<>(knnVectorQueryBuilders.size()); List<DfsKnnResults> knnResults = new ArrayList<>(knnVectorQueryBuilders.size());
for (int i = 0; i < knnSearch.size(); i++) { for (int i = 0; i < knnSearch.size(); i++) {
Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query(); Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query();
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.get(i).k(), Integer.MAX_VALUE); knnResults.add(singleKnnSearch(knnQuery, knnSearch.get(i).k(), context.getProfilers(), context.searcher()));
CollectorManager<Collector, Void> collectorManager = new SingleThreadCollectorManager(topScoreDocCollector);
if (context.getProfilers() != null) {
InternalProfileCollectorManager ipcm = new InternalProfileCollectorManager(
new InternalProfileCollector(collectorManager.newCollector(), CollectorResult.REASON_SEARCH_TOP_HITS)
);
QueryProfiler knnProfiler = context.getProfilers().getDfsProfiler().addQueryProfiler(ipcm);
collectorManager = ipcm;
// Set the current searcher profiler to gather query profiling information for gathering top K docs
context.searcher().setProfiler(knnProfiler);
}
context.searcher().search(knnQuery, collectorManager);
knnResults.add(new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs));
}
// Set profiler back after running KNN searches
if (context.getProfilers() != null) {
context.searcher().setProfiler(context.getProfilers().getCurrentQueryProfiler());
} }
context.dfsResult().knnResults(knnResults); context.dfsResult().knnResults(knnResults);
} }
/**
 * Runs a single kNN query on the given searcher and returns its top {@code k} hits.
 *
 * When {@code profilers} is non-null, the top-docs collector manager is wrapped in a
 * {@link ProfileCollectorManager}, a dedicated query profiler is registered with the DFS profiler,
 * and that profiler is installed on the searcher for the duration of the search. The searcher's
 * profiler is always reset to {@code profilers.getCurrentQueryProfiler()} afterwards, even if the
 * search throws, so a failed kNN search cannot leave the kNN profiler attached to the searcher.
 *
 * @param knnQuery  the already-built kNN query to execute
 * @param k         number of top hits to collect
 * @param profilers profilers of the current search, or {@code null} when profiling is disabled
 * @param searcher  searcher to run the query on
 * @return the collected score docs wrapped in a {@link DfsKnnResults}
 * @throws IOException if the underlying search fails
 */
static DfsKnnResults singleKnnSearch(Query knnQuery, int k, Profilers profilers, ContextIndexSearcher searcher) throws IOException {
    CollectorManager<? extends Collector, TopDocs> cm = TopScoreDocCollector.createSharedManager(k, null, Integer.MAX_VALUE);
    if (profilers != null) {
        ProfileCollectorManager<TopDocs> ipcm = new ProfileCollectorManager<>(cm, CollectorResult.REASON_SEARCH_TOP_HITS);
        QueryProfiler knnProfiler = profilers.getDfsProfiler().addQueryProfiler(ipcm);
        cm = ipcm;
        // Set the current searcher profiler to gather query profiling information for gathering top K docs
        searcher.setProfiler(knnProfiler);
    }
    final TopDocs topDocs;
    try {
        topDocs = searcher.search(knnQuery, cm);
    } finally {
        // Set profiler back after running the KNN search, also when the search failed
        if (profilers != null) {
            searcher.setProfiler(profilers.getCurrentQueryProfiler());
        }
    }
    return new DfsKnnResults(topDocs.scoreDocs);
}
} }

View file

@ -12,7 +12,7 @@ import org.elasticsearch.search.profile.AbstractProfileBreakdown;
import org.elasticsearch.search.profile.ProfileResult; import org.elasticsearch.search.profile.ProfileResult;
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult; import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
import org.elasticsearch.search.profile.Timer; import org.elasticsearch.search.profile.Timer;
import org.elasticsearch.search.profile.query.InternalProfileCollectorManager; import org.elasticsearch.search.profile.query.ProfileCollectorManager;
import org.elasticsearch.search.profile.query.QueryProfileShardResult; import org.elasticsearch.search.profile.query.QueryProfileShardResult;
import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.search.profile.query.QueryProfiler;
@ -51,7 +51,7 @@ public class DfsProfiler extends AbstractProfileBreakdown<DfsTimingType> {
return newTimer; return newTimer;
} }
public QueryProfiler addQueryProfiler(InternalProfileCollectorManager collectorManager) { public QueryProfiler addQueryProfiler(ProfileCollectorManager<?> collectorManager) {
QueryProfiler queryProfiler = new QueryProfiler(); QueryProfiler queryProfiler = new QueryProfiler();
queryProfiler.setCollectorManager(collectorManager::getCollectorTree); queryProfiler.setCollectorManager(collectorManager::getCollectorTree);
knnQueryProfilers.add(queryProfiler); knnQueryProfilers.add(queryProfiler);

View file

@ -13,6 +13,7 @@ import org.apache.lucene.search.CollectorManager;
import java.io.IOException; import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -21,15 +22,15 @@ import java.util.stream.Collectors;
* in an {@link InternalProfileCollector}. It delegates all the profiling to the generated collectors via {@link #getCollectorTree()} * in an {@link InternalProfileCollector}. It delegates all the profiling to the generated collectors via {@link #getCollectorTree()}
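* and joins them up when its {@link #reduce} method is called. The profile result can be retrieved via {@link #getCollectorTree()} once {@link #reduce} has been called. * and joins them up when its {@link #reduce} method is called. The profile result can be retrieved via {@link #getCollectorTree()} once {@link #reduce} has been called.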
*/ */
public final class ProfileCollectorManager implements CollectorManager<InternalProfileCollector, Void> { public final class ProfileCollectorManager<T> implements CollectorManager<InternalProfileCollector, T> {
private final CollectorManager<Collector, ?> collectorManager; private final CollectorManager<Collector, T> collectorManager;
private final String reason; private final String reason;
private CollectorResult collectorTree; private CollectorResult collectorTree;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public ProfileCollectorManager(CollectorManager<? extends Collector, ?> collectorManager, String reason) { public ProfileCollectorManager(CollectorManager<? extends Collector, T> collectorManager, String reason) {
this.collectorManager = (CollectorManager<Collector, ?>) collectorManager; this.collectorManager = (CollectorManager<Collector, T>) collectorManager;
this.reason = reason; this.reason = reason;
} }
@ -38,22 +39,26 @@ public final class ProfileCollectorManager implements CollectorManager<InternalP
return new InternalProfileCollector(collectorManager.newCollector(), reason); return new InternalProfileCollector(collectorManager.newCollector(), reason);
} }
public Void reduce(Collection<InternalProfileCollector> profileCollectors) throws IOException { public T reduce(Collection<InternalProfileCollector> profileCollectors) throws IOException {
assert profileCollectors.size() > 0 : "at least one collector expected";
List<Collector> unwrapped = profileCollectors.stream() List<Collector> unwrapped = profileCollectors.stream()
.map(InternalProfileCollector::getWrappedCollector) .map(InternalProfileCollector::getWrappedCollector)
.collect(Collectors.toList()); .collect(Collectors.toList());
collectorManager.reduce(unwrapped); T returnValue = collectorManager.reduce(unwrapped);
List<CollectorResult> resultsPerProfiler = profileCollectors.stream() List<CollectorResult> resultsPerProfiler = profileCollectors.stream()
.map(ipc -> ipc.getCollectorTree()) .map(ipc -> ipc.getCollectorTree())
.collect(Collectors.toList()); .collect(Collectors.toList());
this.collectorTree = new CollectorResult(this.getClass().getSimpleName(), "segment_search", 0, resultsPerProfiler);
return null; long totalTime = resultsPerProfiler.stream().map(CollectorResult::getTime).reduce(0L, Long::sum);
String collectorName = resultsPerProfiler.get(0).getName();
this.collectorTree = new CollectorResult(collectorName, reason, totalTime, Collections.emptyList());
return returnValue;
} }
public CollectorResult getCollectorTree() { public CollectorResult getCollectorTree() {
if (this.collectorTree == null) { if (this.collectorTree == null) {
throw new IllegalStateException("A collectorTree hasn't been set yet, call reduce() before attempting to retrieve it"); throw new IllegalStateException("A collectorTree hasn't been set yet. Call reduce() before attempting to retrieve it");
} }
return this.collectorTree; return this.collectorTree;
} }

View file

@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.search.dfs;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ThreadPoolExecutor;
/**
 * Unit test for {@link DfsPhase#singleKnnSearch}, run against a {@link ContextIndexSearcher} that is
 * backed by a small fixed thread pool, both with and without profiling enabled.
 */
public class DfsPhaseTests extends ESTestCase {

    // Executor handed to the ContextIndexSearcher under test; created per test in init()
    ThreadPoolExecutor threadPoolExecutor;
    // Supplies the thread context the executor is built with
    private TestThreadPool threadPool;

    /**
     * Creates a fresh thread pool and a fixed executor with 2 to 4 threads before each test.
     */
    @Before
    public final void init() {
        int numThreads = randomIntBetween(2, 4);
        threadPool = new TestThreadPool(DfsPhaseTests.class.getName());
        threadPoolExecutor = EsExecutors.newFixed(
            "test",
            numThreads,
            10,
            EsExecutors.daemonThreadFactory("test"),
            threadPool.getThreadContext(),
            randomBoolean()
        );
    }

    /**
     * Shuts down the executor and terminates the thread pool after each test.
     */
    @After
    public void cleanup() {
        threadPoolExecutor.shutdown();
        terminate(threadPool);
    }

    /**
     * Indexes 900 to 1000 documents with a 3-dimensional float vector, runs the same kNN query once
     * without and once with profilers, and checks the number of hits and the shape of the collector
     * profile result.
     */
    public void testSingleKnnSearch() throws IOException {
        try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
            int numDocs = randomIntBetween(900, 1000);
            for (int i = 0; i < numDocs; i++) {
                Document d = new Document();
                d.add(new KnnFloatVectorField("float_vector", new float[] { i, 0, 0 }));
                w.addDocument(d);
            }
            w.flush();

            // try-with-resources so the reader is released even when a search or assertion below fails
            try (IndexReader reader = w.getReader()) {
                ContextIndexSearcher searcher = new ContextIndexSearcher(
                    reader,
                    IndexSearcher.getDefaultSimilarity(),
                    IndexSearcher.getDefaultQueryCache(),
                    IndexSearcher.getDefaultQueryCachingPolicy(),
                    randomBoolean(),
                    this.threadPoolExecutor
                ) {
                    @Override
                    protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
                        // get a thread per segment
                        return slices(leaves, 1, 1);
                    }
                };

                Query query = new KnnFloatVectorQuery("float_vector", new float[] { 0, 0, 0 }, numDocs, null);

                int k = 10;
                // run without profiling enabled
                DfsKnnResults dfsKnnResults = DfsPhase.singleKnnSearch(query, k, null, searcher);
                assertEquals(k, dfsKnnResults.scoreDocs().length);

                // run with profiling enabled
                Profilers profilers = new Profilers(searcher);
                dfsKnnResults = DfsPhase.singleKnnSearch(query, k, profilers, searcher);
                assertEquals(k, dfsKnnResults.scoreDocs().length);

                SearchProfileDfsPhaseResult searchProfileDfsPhaseResult = profilers.getDfsProfiler().buildDfsPhaseResults();
                List<QueryProfileShardResult> queryProfileShardResult = searchProfileDfsPhaseResult.getQueryProfileShardResult();
                assertNotNull(queryProfileShardResult);

                CollectorResult collectorResult = queryProfileShardResult.get(0).getCollectorResult();
                assertEquals("SimpleTopScoreDocCollector", collectorResult.getName());
                assertEquals("search_top_hits", collectorResult.getReason());
                assertTrue(collectorResult.getTime() > 0);

                // If per-collector children are reported, they must match the parent and add up to its time
                List<CollectorResult> children = collectorResult.getCollectorResults();
                if (children.size() > 0) {
                    long totalTime = 0L;
                    for (CollectorResult child : children) {
                        assertEquals("SimpleTopScoreDocCollector", child.getName());
                        assertEquals("search_top_hits", child.getReason());
                        totalTime += child.getTime();
                    }
                    assertEquals(totalTime, collectorResult.getTime());
                }
            }
        }
    }
}

View file

@ -12,8 +12,6 @@ import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.sandbox.search.ProfilerCollectorResult;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
@ -27,6 +25,7 @@ import org.apache.lucene.util.SetOnce;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -48,29 +47,41 @@ public class ProfileCollectorManagerTests extends ESTestCase {
*/ */
public void testBasic() throws IOException { public void testBasic() throws IOException {
final SetOnce<Boolean> reduceCalled = new SetOnce<>(); final SetOnce<Boolean> reduceCalled = new SetOnce<>();
ProfileCollectorManager pcm = new ProfileCollectorManager(new CollectorManager<>() { ProfileCollectorManager<Integer> pcm = new ProfileCollectorManager<>(new CollectorManager<TestCollector, Integer>() {
private static int counter = 0; private int counter = 0;
@Override @Override
public Collector newCollector() { public TestCollector newCollector() {
return new TestCollector(counter++); return new TestCollector(counter++);
} }
@Override @Override
public Void reduce(Collection<Collector> collectors) { public Integer reduce(Collection<TestCollector> collectors) {
reduceCalled.set(true); reduceCalled.set(true);
return null; return counter;
} }
}, CollectorResult.REASON_SEARCH_TOP_HITS); }, CollectorResult.REASON_SEARCH_TOP_HITS);
for (int i = 0; i < randomIntBetween(5, 10); i++) { int runs = randomIntBetween(5, 10);
InternalProfileCollector internalProfileCollector = pcm.newCollector(); List<InternalProfileCollector> collectors = new ArrayList<>();
assertEquals(i, ((TestCollector) internalProfileCollector.getWrappedCollector()).id); for (int i = 0; i < runs; i++) {
collectors.add(pcm.newCollector());
assertEquals(i, ((TestCollector) collectors.get(i).getWrappedCollector()).id);
} }
pcm.reduce(Collections.emptyList()); Integer returnValue = pcm.reduce(collectors);
assertEquals(runs, returnValue.intValue());
assertTrue(reduceCalled.get()); assertTrue(reduceCalled.get());
} }
public void testReduceEmpty() {
ProfileCollectorManager<TopDocs> pcm = new ProfileCollectorManager<>(
TopScoreDocCollector.createSharedManager(10, null, 1000),
CollectorResult.REASON_SEARCH_TOP_HITS
);
AssertionError ae = expectThrows(AssertionError.class, () -> pcm.reduce(Collections.emptyList()));
assertEquals("at least one collector expected", ae.getMessage());
}
/** /**
* This test checks functionality with potentially more than one slice on a real searcher, * This test checks functionality with potentially more than one slice on a real searcher,
* wrapping a {@link TopScoreDocCollector} into {@link ProfileCollectorManager} and checking the * wrapping a {@link TopScoreDocCollector} into {@link ProfileCollectorManager} and checking the
@ -88,7 +99,6 @@ public class ProfileCollectorManagerTests extends ESTestCase {
writer.flush(); writer.flush();
IndexReader reader = writer.getReader(); IndexReader reader = writer.getReader();
IndexSearcher searcher = newSearcher(reader); IndexSearcher searcher = newSearcher(reader);
int numSlices = searcher.getSlices() == null ? 1 : searcher.getSlices().length;
searcher.setSimilarity(new BM25Similarity()); searcher.setSimilarity(new BM25Similarity());
CollectorManager<TopScoreDocCollector, TopDocs> topDocsManager = TopScoreDocCollector.createSharedManager(10, null, 1000); CollectorManager<TopScoreDocCollector, TopDocs> topDocsManager = TopScoreDocCollector.createSharedManager(10, null, 1000);
@ -96,21 +106,15 @@ public class ProfileCollectorManagerTests extends ESTestCase {
assertEquals(numDocs, topDocs.totalHits.value); assertEquals(numDocs, topDocs.totalHits.value);
String profileReason = "profiler_reason"; String profileReason = "profiler_reason";
ProfileCollectorManager profileCollectorManager = new ProfileCollectorManager(topDocsManager, profileReason); ProfileCollectorManager<TopDocs> profileCollectorManager = new ProfileCollectorManager<>(topDocsManager, profileReason);
searcher.search(new MatchAllDocsQuery(), profileCollectorManager); searcher.search(new MatchAllDocsQuery(), profileCollectorManager);
CollectorResult parent = profileCollectorManager.getCollectorTree(); CollectorResult result = profileCollectorManager.getCollectorTree();
assertEquals("ProfileCollectorManager", parent.getName()); assertEquals("profiler_reason", result.getReason());
assertEquals("segment_search", parent.getReason()); assertEquals("SimpleTopScoreDocCollector", result.getName());
assertEquals(0, parent.getTime()); assertTrue(result.getTime() > 0);
List<ProfilerCollectorResult> delegateCollectorResults = parent.getProfiledChildren();
assertEquals(numSlices, delegateCollectorResults.size());
for (ProfilerCollectorResult pcr : delegateCollectorResults) {
assertEquals("SimpleTopScoreDocCollector", pcr.getName());
assertEquals(profileReason, pcr.getReason());
assertTrue(pcr.getTime() > 0);
}
reader.close(); reader.close();
} }
directory.close(); directory.close();