diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java index dd3e59be2646..6de775c4773b 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -65,7 +65,7 @@ public class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader impleme private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES818BinaryQuantizedVectorsReader.class); - private final Map fields = new HashMap<>(); + private final Map fields; private final IndexInput quantizedVectorData; private final FlatVectorsReader rawVectorsReader; private final ES818BinaryFlatVectorsScorer vectorScorer; @@ -77,6 +77,7 @@ public class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader impleme ES818BinaryFlatVectorsScorer vectorsScorer ) throws IOException { super(vectorsScorer); + this.fields = new HashMap<>(); this.vectorScorer = vectorsScorer; this.rawVectorsReader = rawVectorsReader; int versionMeta = -1; @@ -120,6 +121,24 @@ public class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader impleme } } + private ES818BinaryQuantizedVectorsReader(ES818BinaryQuantizedVectorsReader clone, FlatVectorsReader rawVectorsReader) { + super(clone.vectorScorer); + this.rawVectorsReader = rawVectorsReader; + this.vectorScorer = clone.vectorScorer; + this.quantizedVectorData = clone.quantizedVectorData; + this.fields = clone.fields; + } + + // For testing + FlatVectorsReader getRawVectorsReader() { + return rawVectorsReader; + } + + @Override + public FlatVectorsReader getMergeInstance() { + return new ES818BinaryQuantizedVectorsReader(this, rawVectorsReader.getMergeInstance()); + } + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { FieldInfo info = infos.fieldInfo(fieldNumber); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java index 4d9d7e03848c..e74b0aad1272 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java @@ -36,6 +36,11 @@ class MergeReaderWrapper extends FlatVectorsReader implements OffHeapStats { this.mergeReader = mergeReader; } + // For testing + FlatVectorsReader getMainReader() { + return mainReader; + } + @Override public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { return mainReader.getRandomVectorScorer(field, target); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java index 3b187f2241cc..84fafde5af7c 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -35,6 +36,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.SegmentReader; import org.apache.lucene.index.SoftDeletesRetentionMergePolicy; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; @@ -83,6 +85,7 @@ import java.util.OptionalLong; import static java.lang.String.format; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -309,6 +312,43 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat } } + public void testMergeInstance() throws IOException { + checkDirectIOSupported(); + float[] vector = randomVector(10); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField("field", vector, similarityFunction); + try (Directory dir = newFSDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setUseCompoundFile(false))) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(10)); + doc.add(knnField); + w.addDocument(doc); + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + SegmentReader r = (SegmentReader) getOnlyLeafReader(reader); + assertThat(unwrapRawVectorReader("field", r.getVectorReader()), instanceOf(DirectIOLucene99FlatVectorsReader.class)); + assertThat( + unwrapRawVectorReader("field", r.getVectorReader().getMergeInstance()), + instanceOf(Lucene99FlatVectorsReader.class) + ); + } + } + } + } + + private static KnnVectorsReader unwrapRawVectorReader(String fieldName, KnnVectorsReader knnReader) { + if (knnReader instanceof PerFieldKnnVectorsFormat.FieldsReader perField) { + return unwrapRawVectorReader(fieldName, perField.getFieldReader(fieldName)); + } else if (knnReader instanceof ES818BinaryQuantizedVectorsReader bbqReader) { + return unwrapRawVectorReader(fieldName, bbqReader.getRawVectorsReader()); + } else if (knnReader instanceof MergeReaderWrapper mergeReaderWrapper) { + return unwrapRawVectorReader(fieldName, mergeReaderWrapper.getMainReader()); + } else { + return knnReader; + } + } + static Directory newMMapDirectory() throws IOException { Directory dir = new MMapDirectory(createTempDir("ES818BinaryQuantizedVectorsFormatTests")); if (random().nextBoolean()) {