Build cagra index (iter1)

This commit is contained in:
Mayya Sharipova 2025-06-23 15:55:59 -04:00
parent 2ac22b3d96
commit 2c106fc9ba
3 changed files with 94 additions and 18 deletions

View file

@ -77,9 +77,10 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
} }
/** Tells whether the platform supports cuvs. */ /** Tells whether the platform supports cuvs. */
public static boolean supported() { public static CuVSResources cuVSResourcesOrNull() {
try (var resources = CuVSResources.create()) { try {
return true; var resources = CuVSResources.create();
return resources;
} catch (UnsupportedOperationException uoe) { } catch (UnsupportedOperationException uoe) {
var msg = uoe.getMessage() == null ? "" : ": " + uoe.getMessage(); var msg = uoe.getMessage() == null ? "" : ": " + uoe.getMessage();
LOG.warn("cuvs is not supported on this platform or java version" + msg); LOG.warn("cuvs is not supported on this platform or java version" + msg);
@ -89,6 +90,6 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
} }
LOG.warn("Exception occurred during creation of cuvs resources. " + t); LOG.warn("Exception occurred during creation of cuvs resources. " + t);
} }
return false; return null;
} }
} }

View file

@ -7,6 +7,10 @@
package org.elasticsearch.xpack.gpu.codec; package org.elasticsearch.xpack.gpu.codec;
import com.nvidia.cuvs.CagraIndex;
import com.nvidia.cuvs.CagraIndexParams;
import com.nvidia.cuvs.CuVSResources;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
@ -15,33 +19,44 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.elasticsearch.common.lucene.store.IndexOutputOutputStream;
import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.IOUtils;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/** /**
* Writer for GPU-accelerated vectors. * Writer for GPU-accelerated vectors.
*/ */
public class GPUVectorsWriter extends KnnVectorsWriter { public class GPUVectorsWriter extends KnnVectorsWriter {
private static final Logger logger = LogManager.getLogger(GPUVectorsWriter.class);
private final List<FieldWriter> fieldWriters = new ArrayList<>(); private final List<FieldWriter> fieldWriters = new ArrayList<>();
private final IndexOutput gpuIdx; private final IndexOutput gpuIdx;
private final IndexOutput gpuMeta; private final IndexOutput gpuMeta;
private final FlatVectorsWriter rawVectorDelegate; private final FlatVectorsWriter rawVectorDelegate;
private final SegmentWriteState segmentWriteState; private final SegmentWriteState segmentWriteState;
private final CuVSResources cuVSResources;
@SuppressWarnings("this-escape") @SuppressWarnings("this-escape")
public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException { public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
this.cuVSResources = GPUVectorsFormat.cuVSResourcesOrNull();
if (cuVSResources == null) {
throw new IllegalArgumentException("GPU based vector search is not supported on this platform or java version");
}
this.segmentWriteState = state; this.segmentWriteState = state;
this.rawVectorDelegate = rawVectorDelegate; this.rawVectorDelegate = rawVectorDelegate;
final String metaFileName = IndexFileNames.segmentFileName( final String metaFileName = IndexFileNames.segmentFileName(
@ -95,15 +110,61 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
@Override @Override
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap); rawVectorDelegate.flush(maxDoc, sortMap);
// TODO: implement the case when sortMap != null
for (FieldWriter fieldWriter : fieldWriters) { for (FieldWriter fieldWriter : fieldWriters) {
// TODO: Implement GPU-specific vector merging instead of bogus implementation // TODO: can we use MemorySegment instead of passing array of vectors
float[][] vectors = fieldWriter.delegate.getVectors().toArray(float[][]::new);
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES); long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
var vectors = fieldWriter.delegate.getVectors(); try {
for (int i = 0; i < vectors.size(); i++) { writeGPUIndex(fieldWriter.fieldInfo.getVectorSimilarityFunction(), vectors);
gpuIdx.writeVInt(0); long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
} catch (IOException e) {
throw e;
} catch (Throwable t) {
throw new IOException("Failed to write GPU index: ", t);
} }
long dataLength = gpuIdx.getFilePointer() - dataOffset; }
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength); }
private void writeGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
// https://github.com/rapidsai/cuvs/issues/666
// TODO: do Lucene HNSW index write here
if (vectors.length < 2) {
throw new IllegalStateException("Must be more than [1] vectors in a segment");
}
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
};
// TODO: expose cagra index params: intermediate graph degree, graph degree, build algorithm, NNDescentNumIterations
CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
.withMetric(distanceType)
.build();
// build index on GPU
long startTime = System.nanoTime();
var index = CagraIndex.newBuilder(cuVSResources).withDataset(vectors).withIndexParams(params).build();
if (logger.isDebugEnabled()) {
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, vectors.length);
}
// TODO: do serialization through MemorySegment instead of a temp file
// serialize index for CPU consumption
startTime = System.nanoTime();
var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx);
try {
index.serialize(gpuIndexOutputStream);
if (logger.isDebugEnabled()) {
logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
}
} finally {
index.destroyIndex();
} }
} }
@ -111,14 +172,27 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState); rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
// TODO: Implement GPU-specific vector merging instead of bogus implementation FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES); float[][] vectors = new float[vectorValues.size()][vectorValues.dimension()];
for (int i = 0; i < floatVectorValues.size(); i++) {
gpuIdx.writeVInt(0); int cnt = 0;
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
vectors[cnt++] = vectorValues.vectorValue(iter.index());
}
// TODO: check that the number of collected vectors still matches size() when the merged segments contain deleted documents
assert (vectorValues.size() == cnt);
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
try {
writeGPUIndex(fieldInfo.getVectorSimilarityFunction(), vectors);
long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldInfo, dataOffset, dataLength);
} catch (IOException e) {
throw e;
} catch (Throwable t) {
throw new IOException("Failed to write GPU index: ", t);
} }
long dataLength = gpuIdx.getFilePointer() - dataOffset;
writeMeta(fieldInfo, dataOffset, dataLength);
} else { } else {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState); rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
} }
@ -157,6 +231,7 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
@Override @Override
public final void close() throws IOException { public final void close() throws IOException {
IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx); IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx);
cuVSResources.close();
} }
@Override @Override

View file

@ -25,7 +25,7 @@ public class GPUVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
@BeforeClass @BeforeClass
public static void beforeClass() { public static void beforeClass() {
assumeTrue("cuvs not supported", GPUVectorsFormat.supported()); assumeTrue("cuvs not supported", GPUVectorsFormat.cuVSResourcesOrNull() != null);
} }
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat()); static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat());