[IVF] Improve the format of the tmp file written during merging (#129828)

This commit separe vector and docIds on the tmp file.
This commit is contained in:
Ignacio Vera 2025-06-23 14:44:00 +02:00 committed by GitHub
parent b1741e8a96
commit 72b488cfa9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -28,6 +28,7 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.SuppressForbidden;
@ -237,25 +238,49 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException { private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
final int numVectors; final int numVectors;
String tempRawVectorsFileName = null; String tempRawVectorsFileName = null;
String docsFileName = null;
boolean success = false; boolean success = false;
// build a float vector values with random access. In order to do that we dump the vectors to // build a float vector values with random access. In order to do that we dump the vectors to
// a temporary file // a temporary file and if the segment is not dense, the docs to another file/
// and write the docID follow by the vector try (
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { IndexOutput vectorsOut = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivfvec_", IOContext.DEFAULT)
tempRawVectorsFileName = out.getName(); ) {
tempRawVectorsFileName = vectorsOut.getName();
FloatVectorValues mergedFloatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
// if the segment is dense, we don't need to do anything with docIds.
boolean dense = mergedFloatVectorValues.size() == mergeState.segmentInfo.maxDoc();
try (
IndexOutput docsOut = dense
? null
: mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivfdoc_", IOContext.DEFAULT)
) {
if (docsOut != null) {
docsFileName = docsOut.getName();
}
// TODO do this better, we shouldn't have to write to a temp file, we should be able to // TODO do this better, we shouldn't have to write to a temp file, we should be able to
// to just from the merged vector values, the tricky part is the random access. // to just from the merged vector values, the tricky part is the random access.
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); numVectors = writeFloatVectorValues(fieldInfo, docsOut, vectorsOut, mergedFloatVectorValues);
CodecUtil.writeFooter(out); CodecUtil.writeFooter(vectorsOut);
if (docsOut != null) {
CodecUtil.writeFooter(docsOut);
}
success = true; success = true;
}
} finally { } finally {
if (success == false && tempRawVectorsFileName != null) { if (success == false) {
if (tempRawVectorsFileName != null) {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
} }
if (docsFileName != null) {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, docsFileName);
} }
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) { }
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; }
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors); try (
IndexInput vectors = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT);
IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT)
) {
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
success = false; success = false;
long centroidOffset; long centroidOffset;
long centroidLength; long centroidLength;
@ -263,10 +288,10 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
int numCentroids; int numCentroids;
IndexOutput centroidTemp = null; IndexOutput centroidTemp = null;
CentroidAssignments centroidAssignments; CentroidAssignments centroidAssignments;
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
try { try {
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
centroidTempName = centroidTemp.getName(); centroidTempName = centroidTemp.getName();
centroidAssignments = calculateAndWriteCentroids( centroidAssignments = calculateAndWriteCentroids(
fieldInfo, fieldInfo,
floatVectorValues, floatVectorValues,
@ -318,28 +343,34 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
} }
} finally { } finally {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
}
} finally {
if (docsFileName != null) {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions( org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(
mergeState.segmentInfo.dir, mergeState.segmentInfo.dir,
tempRawVectorsFileName, tempRawVectorsFileName,
centroidTempName docsFileName
); );
} } else {
} finally {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
} }
} }
}
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) { private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput docs, IndexInput vectors, int numVectors)
throws IOException {
if (numVectors == 0) { if (numVectors == 0) {
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension()); return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
} }
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES; final long vectorLength = (long) Float.BYTES * fieldInfo.getVectorDimension();
final float[] vector = new float[fieldInfo.getVectorDimension()]; final float[] vector = new float[fieldInfo.getVectorDimension()];
final RandomAccessInput randomDocs = docs == null ? null : docs.randomAccessSlice(0, docs.length());
return new FloatVectorValues() { return new FloatVectorValues() {
@Override @Override
public float[] vectorValue(int ord) throws IOException { public float[] vectorValue(int ord) throws IOException {
randomAccessInput.seek(ord * length + Integer.BYTES); vectors.seek(ord * vectorLength);
randomAccessInput.readFloats(vector, 0, vector.length); vectors.readFloats(vector, 0, vector.length);
return vector; return vector;
} }
@ -360,9 +391,11 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
@Override @Override
public int ordToDoc(int ord) { public int ordToDoc(int ord) {
if (randomDocs == null) {
return ord;
}
try { try {
randomAccessInput.seek(ord * length); return randomDocs.readInt((long) ord * Integer.BYTES);
return randomAccessInput.readInt();
} catch (IOException e) { } catch (IOException e) {
throw new UncheckedIOException(e); throw new UncheckedIOException(e);
} }
@ -370,17 +403,22 @@ public abstract class IVFVectorsWriter extends KnnVectorsWriter {
}; };
} }
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues) private static int writeFloatVectorValues(
throws IOException { FieldInfo fieldInfo,
IndexOutput docsOut,
IndexOutput vectorsOut,
FloatVectorValues floatVectorValues
) throws IOException {
int numVectors = 0; int numVectors = 0;
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
numVectors++; numVectors++;
float[] vector = floatVectorValues.vectorValue(iterator.index()); buffer.asFloatBuffer().put(floatVectorValues.vectorValue(iterator.index()));
out.writeInt(iterator.docID()); vectorsOut.writeBytes(buffer.array(), buffer.array().length);
buffer.asFloatBuffer().put(vector); if (docsOut != null) {
out.writeBytes(buffer.array(), buffer.array().length); docsOut.writeInt(iterator.docID());
}
} }
return numVectors; return numVectors;
} }