diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index 30881a90adfc..844d9348bd82 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -28,6 +28,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.SuppressForbidden; @@ -237,25 +238,49 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException { final int numVectors; String tempRawVectorsFileName = null; + String docsFileName = null; boolean success = false; // build a float vector values with random access. In order to do that we dump the vectors to - // a temporary file - // and write the docID follow by the vector - try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { - tempRawVectorsFileName = out.getName(); - // TODO do this better, we shouldn't have to write to a temp file, we should be able to - // to just from the merged vector values, the tricky part is the random access. - numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); - CodecUtil.writeFooter(out); - success = true; + // a temporary file and if the segment is not dense, the docs to another file/ + try ( + IndexOutput vectorsOut = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivfvec_", IOContext.DEFAULT) + ) { + tempRawVectorsFileName = vectorsOut.getName(); + FloatVectorValues mergedFloatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + // if the segment is dense, we don't need to do anything with docIds. + boolean dense = mergedFloatVectorValues.size() == mergeState.segmentInfo.maxDoc(); + try ( + IndexOutput docsOut = dense + ? null + : mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivfdoc_", IOContext.DEFAULT) + ) { + if (docsOut != null) { + docsFileName = docsOut.getName(); + } + // TODO do this better, we shouldn't have to write to a temp file, we should be able to + // to just from the merged vector values, the tricky part is the random access. + numVectors = writeFloatVectorValues(fieldInfo, docsOut, vectorsOut, mergedFloatVectorValues); + CodecUtil.writeFooter(vectorsOut); + if (docsOut != null) { + CodecUtil.writeFooter(docsOut); + } + success = true; + } } finally { - if (success == false && tempRawVectorsFileName != null) { - org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); + if (success == false) { + if (tempRawVectorsFileName != null) { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); + } + if (docsFileName != null) { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, docsFileName); + } } } - try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) { - float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; - final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors); + try ( + IndexInput vectors = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT); + IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT) + ) { + final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors); success = false; long centroidOffset; long centroidLength; @@ -263,10 +288,10 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { int numCentroids; IndexOutput centroidTemp = null; CentroidAssignments centroidAssignments; + float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; try { centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTempName = centroidTemp.getName(); - centroidAssignments = calculateAndWriteCentroids( fieldInfo, floatVectorValues, @@ -318,28 +343,34 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); } } finally { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + } + } finally { + if (docsFileName != null) { org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions( mergeState.segmentInfo.dir, tempRawVectorsFileName, - centroidTempName + docsFileName ); + } else { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); } - } finally { - org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); } } - private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) { + private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput docs, IndexInput vectors, int numVectors) + throws IOException { if (numVectors == 0) { return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension()); } - final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES; + final long vectorLength = (long) Float.BYTES * fieldInfo.getVectorDimension(); final float[] vector = new float[fieldInfo.getVectorDimension()]; + final RandomAccessInput randomDocs = docs == null ? null : docs.randomAccessSlice(0, docs.length()); return new FloatVectorValues() { @Override public float[] vectorValue(int ord) throws IOException { - randomAccessInput.seek(ord * length + Integer.BYTES); - randomAccessInput.readFloats(vector, 0, vector.length); + vectors.seek(ord * vectorLength); + vectors.readFloats(vector, 0, vector.length); return vector; } @@ -360,9 +391,11 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { @Override public int ordToDoc(int ord) { + if (randomDocs == null) { + return ord; + } try { - randomAccessInput.seek(ord * length); - return randomAccessInput.readInt(); + return randomDocs.readInt((long) ord * Integer.BYTES); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -370,17 +403,22 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter { }; } - private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues) - throws IOException { + private static int writeFloatVectorValues( + FieldInfo fieldInfo, + IndexOutput docsOut, + IndexOutput vectorsOut, + FloatVectorValues floatVectorValues + ) throws IOException { int numVectors = 0; final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { numVectors++; - float[] vector = floatVectorValues.vectorValue(iterator.index()); - out.writeInt(iterator.docID()); - buffer.asFloatBuffer().put(vector); - out.writeBytes(buffer.array(), buffer.array().length); + buffer.asFloatBuffer().put(floatVectorValues.vectorValue(iterator.index())); + vectorsOut.writeBytes(buffer.array(), buffer.array().length); + if (docsOut != null) { + docsOut.writeInt(iterator.docID()); + } } return numVectors; }