mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 01:22:26 -04:00
Build cagra index (iter1)
This commit is contained in:
parent
2ac22b3d96
commit
2c106fc9ba
3 changed files with 94 additions and 18 deletions
|
@ -77,9 +77,10 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Tells whether the platform supports cuvs. */
|
/** Tells whether the platform supports cuvs. */
|
||||||
public static boolean supported() {
|
public static CuVSResources cuVSResourcesOrNull() {
|
||||||
try (var resources = CuVSResources.create()) {
|
try {
|
||||||
return true;
|
var resources = CuVSResources.create();
|
||||||
|
return resources;
|
||||||
} catch (UnsupportedOperationException uoe) {
|
} catch (UnsupportedOperationException uoe) {
|
||||||
var msg = uoe.getMessage() == null ? "" : ": " + uoe.getMessage();
|
var msg = uoe.getMessage() == null ? "" : ": " + uoe.getMessage();
|
||||||
LOG.warn("cuvs is not supported on this platform or java version" + msg);
|
LOG.warn("cuvs is not supported on this platform or java version" + msg);
|
||||||
|
@ -89,6 +90,6 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
|
||||||
}
|
}
|
||||||
LOG.warn("Exception occurred during creation of cuvs resources. " + t);
|
LOG.warn("Exception occurred during creation of cuvs resources. " + t);
|
||||||
}
|
}
|
||||||
return false;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,10 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.gpu.codec;
|
package org.elasticsearch.xpack.gpu.codec;
|
||||||
|
|
||||||
|
import com.nvidia.cuvs.CagraIndex;
|
||||||
|
import com.nvidia.cuvs.CagraIndexParams;
|
||||||
|
import com.nvidia.cuvs.CuVSResources;
|
||||||
|
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
@ -15,33 +19,44 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.KnnVectorValues;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.IndexOutput;
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
import org.elasticsearch.common.lucene.store.IndexOutputOutputStream;
|
||||||
import org.elasticsearch.core.IOUtils;
|
import org.elasticsearch.core.IOUtils;
|
||||||
|
import org.elasticsearch.logging.LogManager;
|
||||||
|
import org.elasticsearch.logging.Logger;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writer for GPU-accelerated vectors.
|
* Writer for GPU-accelerated vectors.
|
||||||
*/
|
*/
|
||||||
public class GPUVectorsWriter extends KnnVectorsWriter {
|
public class GPUVectorsWriter extends KnnVectorsWriter {
|
||||||
|
private static final Logger logger = LogManager.getLogger(GPUVectorsWriter.class);
|
||||||
|
|
||||||
private final List<FieldWriter> fieldWriters = new ArrayList<>();
|
private final List<FieldWriter> fieldWriters = new ArrayList<>();
|
||||||
private final IndexOutput gpuIdx;
|
private final IndexOutput gpuIdx;
|
||||||
private final IndexOutput gpuMeta;
|
private final IndexOutput gpuMeta;
|
||||||
private final FlatVectorsWriter rawVectorDelegate;
|
private final FlatVectorsWriter rawVectorDelegate;
|
||||||
private final SegmentWriteState segmentWriteState;
|
private final SegmentWriteState segmentWriteState;
|
||||||
|
private final CuVSResources cuVSResources;
|
||||||
|
|
||||||
@SuppressWarnings("this-escape")
|
@SuppressWarnings("this-escape")
|
||||||
public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
|
public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
|
||||||
|
this.cuVSResources = GPUVectorsFormat.cuVSResourcesOrNull();
|
||||||
|
if (cuVSResources == null) {
|
||||||
|
throw new IllegalArgumentException("GPU based vector search is not supported on this platform or java version");
|
||||||
|
}
|
||||||
this.segmentWriteState = state;
|
this.segmentWriteState = state;
|
||||||
this.rawVectorDelegate = rawVectorDelegate;
|
this.rawVectorDelegate = rawVectorDelegate;
|
||||||
final String metaFileName = IndexFileNames.segmentFileName(
|
final String metaFileName = IndexFileNames.segmentFileName(
|
||||||
|
@ -95,15 +110,61 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
|
||||||
@Override
|
@Override
|
||||||
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||||
rawVectorDelegate.flush(maxDoc, sortMap);
|
rawVectorDelegate.flush(maxDoc, sortMap);
|
||||||
|
// TODO: implement the case when sortMap != null
|
||||||
|
|
||||||
for (FieldWriter fieldWriter : fieldWriters) {
|
for (FieldWriter fieldWriter : fieldWriters) {
|
||||||
// TODO: Implement GPU-specific vector merging instead of bogus implementation
|
// TODO: can we use MemorySegment instead of passing array of vectors
|
||||||
|
float[][] vectors = fieldWriter.delegate.getVectors().toArray(float[][]::new);
|
||||||
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
|
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
|
||||||
var vectors = fieldWriter.delegate.getVectors();
|
try {
|
||||||
for (int i = 0; i < vectors.size(); i++) {
|
writeGPUIndex(fieldWriter.fieldInfo.getVectorSimilarityFunction(), vectors);
|
||||||
gpuIdx.writeVInt(0);
|
long dataLength = gpuIdx.getFilePointer() - dataOffset;
|
||||||
|
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw e;
|
||||||
|
} catch (Throwable t) {
|
||||||
|
throw new IOException("Failed to write GPU index: ", t);
|
||||||
}
|
}
|
||||||
long dataLength = gpuIdx.getFilePointer() - dataOffset;
|
}
|
||||||
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
|
}
|
||||||
|
|
||||||
|
private void writeGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
|
||||||
|
// https://github.com/rapidsai/cuvs/issues/666
|
||||||
|
// TODO: do Lucene HNSW index write here
|
||||||
|
if (vectors.length < 2) {
|
||||||
|
throw new IllegalStateException("Must be more than [1] vectors in a segment");
|
||||||
|
}
|
||||||
|
|
||||||
|
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
|
||||||
|
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
|
||||||
|
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
|
||||||
|
case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: expose cagra index params of intermediate graph degree, graph degre, algorithm, NNDescentNumIterations
|
||||||
|
CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
|
||||||
|
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
|
||||||
|
.withMetric(distanceType)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// build index on GPU
|
||||||
|
long startTime = System.nanoTime();
|
||||||
|
var index = CagraIndex.newBuilder(cuVSResources).withDataset(vectors).withIndexParams(params).build();
|
||||||
|
if (logger.isDebugEnabled()) {
|
||||||
|
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, vectors.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: do serialization through MemorySegment instead of a temp file
|
||||||
|
// serialize index for CPU consumption
|
||||||
|
startTime = System.nanoTime();
|
||||||
|
var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx);
|
||||||
|
try {
|
||||||
|
index.serialize(gpuIndexOutputStream);
|
||||||
|
if (logger.isDebugEnabled()) {
|
||||||
|
logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
index.destroyIndex();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,14 +172,27 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
|
||||||
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||||
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
||||||
// TODO: Implement GPU-specific vector merging instead of bogus implementation
|
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||||
FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
|
||||||
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
|
float[][] vectors = new float[vectorValues.size()][vectorValues.dimension()];
|
||||||
for (int i = 0; i < floatVectorValues.size(); i++) {
|
|
||||||
gpuIdx.writeVInt(0);
|
int cnt = 0;
|
||||||
|
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||||
|
vectors[cnt++] = vectorValues.vectorValue(iter.index());
|
||||||
|
}
|
||||||
|
// TODO: check that the number corresponds for deleted documents
|
||||||
|
assert (vectorValues.size() == cnt);
|
||||||
|
|
||||||
|
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
|
||||||
|
try {
|
||||||
|
writeGPUIndex(fieldInfo.getVectorSimilarityFunction(), vectors);
|
||||||
|
long dataLength = gpuIdx.getFilePointer() - dataOffset;
|
||||||
|
writeMeta(fieldInfo, dataOffset, dataLength);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw e;
|
||||||
|
} catch (Throwable t) {
|
||||||
|
throw new IOException("Failed to write GPU index: ", t);
|
||||||
}
|
}
|
||||||
long dataLength = gpuIdx.getFilePointer() - dataOffset;
|
|
||||||
writeMeta(fieldInfo, dataOffset, dataLength);
|
|
||||||
} else {
|
} else {
|
||||||
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
||||||
}
|
}
|
||||||
|
@ -157,6 +231,7 @@ public class GPUVectorsWriter extends KnnVectorsWriter {
|
||||||
@Override
|
@Override
|
||||||
public final void close() throws IOException {
|
public final void close() throws IOException {
|
||||||
IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx);
|
IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx);
|
||||||
|
cuVSResources.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -25,7 +25,7 @@ public class GPUVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
|
||||||
|
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
public static void beforeClass() {
|
public static void beforeClass() {
|
||||||
assumeTrue("cuvs not supported", GPUVectorsFormat.supported());
|
assumeTrue("cuvs not supported", GPUVectorsFormat.cuVSResourcesOrNull() != null);
|
||||||
}
|
}
|
||||||
|
|
||||||
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat());
|
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue