Build cagra index (iter1)

This commit is contained in:
Mayya Sharipova 2025-06-23 15:55:59 -04:00
parent 2ac22b3d96
commit 2c106fc9ba
3 changed files with 94 additions and 18 deletions

View file

@ -77,9 +77,10 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
}
/** Tells whether the platform supports cuvs. */
public static boolean supported() {
try (var resources = CuVSResources.create()) {
return true;
public static CuVSResources cuVSResourcesOrNull() {
try {
var resources = CuVSResources.create();
return resources;
} catch (UnsupportedOperationException uoe) {
var msg = uoe.getMessage() == null ? "" : ": " + uoe.getMessage();
LOG.warn("cuvs is not supported on this platform or java version" + msg);
@ -89,6 +90,6 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
}
LOG.warn("Exception occurred during creation of cuvs resources. " + t);
}
return false;
return null;
}
}

View file

@ -7,6 +7,10 @@
package org.elasticsearch.xpack.gpu.codec;
import com.nvidia.cuvs.CagraIndex;
import com.nvidia.cuvs.CagraIndexParams;
import com.nvidia.cuvs.CuVSResources;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -15,33 +19,44 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexOutput;
import org.elasticsearch.common.lucene.store.IndexOutputOutputStream;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/**
* Writer for GPU-accelerated vectors.
*/
public class GPUVectorsWriter extends KnnVectorsWriter {
private static final Logger logger = LogManager.getLogger(GPUVectorsWriter.class);
private final List<FieldWriter> fieldWriters = new ArrayList<>();
private final IndexOutput gpuIdx;
private final IndexOutput gpuMeta;
private final FlatVectorsWriter rawVectorDelegate;
private final SegmentWriteState segmentWriteState;
private final CuVSResources cuVSResources;
@SuppressWarnings("this-escape")
public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
this.cuVSResources = GPUVectorsFormat.cuVSResourcesOrNull();
if (cuVSResources == null) {
throw new IllegalArgumentException("GPU based vector search is not supported on this platform or java version");
}
this.segmentWriteState = state;
this.rawVectorDelegate = rawVectorDelegate;
final String metaFileName = IndexFileNames.segmentFileName(
@ -95,15 +110,61 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
@Override
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
// TODO: implement the case when sortMap != null
for (FieldWriter fieldWriter : fieldWriters) {
// TODO: Implement GPU-specific vector merging instead of bogus implementation
// TODO: can we use MemorySegment instead of passing array of vectors
float[][] vectors = fieldWriter.delegate.getVectors().toArray(float[][]::new);
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
var vectors = fieldWriter.delegate.getVectors();
for (int i = 0; i < vectors.size(); i++) {
gpuIdx.writeVInt(0);
try {
writeGPUIndex(fieldWriter.fieldInfo.getVectorSimilarityFunction(), vectors);
long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
} catch (IOException e) {
throw e;
} catch (Throwable t) {
throw new IOException("Failed to write GPU index: ", t);
}
long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
}
}
private void writeGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
// https://github.com/rapidsai/cuvs/issues/666
// TODO: do Lucene HNSW index write here
if (vectors.length < 2) {
throw new IllegalStateException("Must be more than [1] vectors in a segment");
}
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
};
// TODO: expose cagra index params of intermediate graph degree, graph degre, algorithm, NNDescentNumIterations
CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
.withMetric(distanceType)
.build();
// build index on GPU
long startTime = System.nanoTime();
var index = CagraIndex.newBuilder(cuVSResources).withDataset(vectors).withIndexParams(params).build();
if (logger.isDebugEnabled()) {
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, vectors.length);
}
// TODO: do serialization through MemorySegment instead of a temp file
// serialize index for CPU consumption
startTime = System.nanoTime();
var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx);
try {
index.serialize(gpuIndexOutputStream);
if (logger.isDebugEnabled()) {
logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
}
} finally {
index.destroyIndex();
}
}
@ -111,14 +172,27 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
// TODO: Implement GPU-specific vector merging instead of bogus implementation
FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
for (int i = 0; i < floatVectorValues.size(); i++) {
gpuIdx.writeVInt(0);
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
float[][] vectors = new float[vectorValues.size()][vectorValues.dimension()];
int cnt = 0;
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
vectors[cnt++] = vectorValues.vectorValue(iter.index());
}
// TODO: check that the number corresponds for deleted documents
assert (vectorValues.size() == cnt);
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
try {
writeGPUIndex(fieldInfo.getVectorSimilarityFunction(), vectors);
long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldInfo, dataOffset, dataLength);
} catch (IOException e) {
throw e;
} catch (Throwable t) {
throw new IOException("Failed to write GPU index: ", t);
}
long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldInfo, dataOffset, dataLength);
} else {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
}
@ -157,6 +231,7 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
@Override
public final void close() throws IOException {
IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx);
cuVSResources.close();
}
@Override

View file

@ -25,7 +25,7 @@ public class GPUVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
@BeforeClass
public static void beforeClass() {
assumeTrue("cuvs not supported", GPUVectorsFormat.supported());
assumeTrue("cuvs not supported", GPUVectorsFormat.cuVSResourcesOrNull() != null);
}
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat());