From ad0fe78e3e9d866f7eb97c7d73c67db321c01ff1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 23 Apr 2025 12:34:20 -0400 Subject: [PATCH] Add docIds writer for future vector format usage (#127185) --- .../index/codec/vectors/DocIdsWriter.java | 376 ++++++++++++++++++ .../codec/vectors/DocIdsWriterTests.java | 154 +++++++ 2 files changed, 530 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java new file mode 100644 index 000000000000..ca4bbf0377e0 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DocIdsWriter.java @@ -0,0 +1,376 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.index.PointValues.IntersectVisitor; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.DataOutput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.DocBaseBitSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.IntsRef; +import org.apache.lucene.util.LongsRef; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +import java.io.IOException; +import java.util.Arrays; + +/** + * This class is used to write and read the doc ids in a compressed format. The format is optimized + * for the number of bits per value (bpv) and the number of values. + * + *
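+ * <p>Encoding is chosen per block from the distribution of the ids, and a tag byte is written
+ * first: CONTINUOUS_IDS (-2) for runs of consecutive ids (and empty blocks), BITSET_IDS (-1) for
+ * strictly sorted ids where max - min + 1 <= 16 * count, DELTA_BPV_16 when deltas against the
+ * minimum fit in 16 bits, BPV_21 and BPV_24 for ids that fit in 21 and 24 bits, and BPV_32 for
+ * plain 32-bit ints.
+ *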
+ * <p>It is copied from the BKD implementation.
+ */
+final class DocIdsWriter {
+    public static final int DEFAULT_MAX_POINTS_IN_LEAF_NODE = 512;
+
+    private static final byte CONTINUOUS_IDS = (byte) -2;
+    private static final byte BITSET_IDS = (byte) -1;
+    private static final byte DELTA_BPV_16 = (byte) 16;
+    private static final byte BPV_21 = (byte) 21;
+    private static final byte BPV_24 = (byte) 24;
+    private static final byte BPV_32 = (byte) 32;
+
+    private int[] scratch = new int[0];
+    private final LongsRef scratchLongs = new LongsRef();
+
+    /**
+     * IntsRef to be used to iterate over the scratch buffer. A single instance is reused to avoid
+     * re-allocating the object. The ints and length fields need to be reset before each use.
+     *
+     * <p>The main reason for existing is to be able to call the {@link
+     * IntersectVisitor#visit(IntsRef)} method rather than the {@link IntersectVisitor#visit(int)}
+     * method. This seems to make a difference in performance, probably due to fewer virtual calls
+     * happening (once per read call rather than once per doc).
+     */
+    private final IntsRef scratchIntsRef = new IntsRef();
+
+    {
+        // This is here to not rely on the default constructor of IntsRef to set offset to 0
+        scratchIntsRef.offset = 0;
+    }
+
+    DocIdsWriter() {}
+
+    void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException {
+        // docs can be sorted either when all docs in a block have the same value
+        // or when a segment is sorted
+        if (count == 0) {
+            out.writeByte(CONTINUOUS_IDS);
+            return;
+        }
+        if (count > scratch.length) {
+            scratch = new int[count];
+        }
+        boolean strictlySorted = true;
+        int min = docIds.apply(0);
+        int max = min;
+        for (int i = 1; i < count; ++i) {
+            int last = docIds.apply(i - 1);
+            int current = docIds.apply(i);
+            if (last >= current) {
+                strictlySorted = false;
+            }
+            min = Math.min(min, current);
+            max = Math.max(max, current);
+        }
+
+        int min2max = max - min + 1;
+        if (strictlySorted) {
+            if (min2max == count) {
+                // continuous ids, typically happens when segment is sorted
+                out.writeByte(CONTINUOUS_IDS);
+                out.writeVInt(docIds.apply(0));
+                return;
+            } else if (min2max <= (count << 4)) {
+                assert min2max > count : "min2max: " + min2max + ", count: " + count;
+                // Only trigger the bitset optimization when max - min + 1 <= 16 * count, to avoid
+                // expanding storage too much.
+                // A field with lower cardinality is more likely to trigger this optimization.
+                out.writeByte(BITSET_IDS);
+                writeIdsAsBitSet(docIds, count, out);
+                return;
+            }
+        }
+
+        if (min2max <= 0xFFFF) {
+            out.writeByte(DELTA_BPV_16);
+            for (int i = 0; i < count; i++) {
+                scratch[i] = docIds.apply(i) - min;
+            }
+            out.writeVInt(min);
+            final int halfLen = count >> 1;
+            for (int i = 0; i < halfLen; ++i) {
+                scratch[i] = scratch[halfLen + i] | (scratch[i] << 16);
+            }
+            for (int i = 0; i < halfLen; i++) {
+                out.writeInt(scratch[i]);
+            }
+            if ((count & 1) == 1) {
+                out.writeShort((short) scratch[count - 1]);
+            }
+        } else {
+            if (max <= 0x1FFFFF) {
+                out.writeByte(BPV_21);
+                final int oneThird = floorToMultipleOf16(count / 3);
+                final int numInts = oneThird * 2;
+                for (int i = 0; i < numInts; i++) {
+                    scratch[i] = docIds.apply(i) << 11;
+                }
+                for (int i = 0; i < oneThird; i++) {
+                    final int longIdx = i + numInts;
+                    scratch[i] |= docIds.apply(longIdx) & 0x7FF;
+                    scratch[i + oneThird] |= (docIds.apply(longIdx) >>> 11) & 0x7FF;
+                }
+                for (int i = 0; i < numInts; i++) {
+                    out.writeInt(scratch[i]);
+                }
+                int i = oneThird * 3;
+                for (; i < count - 2; i += 3) {
+                    out.writeLong(((long) docIds.apply(i)) | (((long) docIds.apply(i + 1)) << 21) | (((long) docIds.apply(i + 2)) << 42));
+                }
+                for (; i < count; ++i) {
+                    out.writeShort((short) docIds.apply(i));
+                    out.writeByte((byte) (docIds.apply(i) >>> 16));
+                }
+            } else if (max <= 0xFFFFFF) {
+                out.writeByte(BPV_24);
+
+                // encode the docs in a format that can be decoded with vectorized instructions.
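+                // Worked example: with count = 8, quarter = 2 and numInts = 6, docs 0..5 land in
+                // the high 24 bits of scratch[0..5]; doc 6 is then split byte-wise into the low
+                // bytes of scratch[0], scratch[2] and scratch[4], and doc 7 into scratch[1],
+                // scratch[3] and scratch[5]. Each written int carries 32 useful bits, and
+                // decode24 below undoes the transform with two tight, vectorizable loops.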
+ final int quarter = count >> 2; + final int numInts = quarter * 3; + for (int i = 0; i < numInts; i++) { + scratch[i] = docIds.apply(i) << 8; + } + for (int i = 0; i < quarter; i++) { + final int longIdx = i + numInts; + scratch[i] |= docIds.apply(longIdx) & 0xFF; + scratch[i + quarter] |= (docIds.apply(longIdx) >>> 8) & 0xFF; + scratch[i + quarter * 2] |= docIds.apply(longIdx) >>> 16; + } + for (int i = 0; i < numInts; i++) { + out.writeInt(scratch[i]); + } + for (int i = quarter << 2; i < count; ++i) { + out.writeShort((short) docIds.apply(i)); + out.writeByte((byte) (docIds.apply(i) >>> 16)); + } + } else { + out.writeByte(BPV_32); + for (int i = 0; i < count; i++) { + out.writeInt(docIds.apply(i)); + } + } + } + } + + private static void writeIdsAsBitSet(IntToIntFunction docIds, int count, DataOutput out) throws IOException { + int min = docIds.apply(0); + int max = docIds.apply(count - 1); + + final int offsetWords = min >> 6; + final int offsetBits = offsetWords << 6; + final int totalWordCount = FixedBitSet.bits2words(max - offsetBits + 1); + long currentWord = 0; + int currentWordIndex = 0; + + out.writeVInt(offsetWords); + out.writeVInt(totalWordCount); + // build bit set streaming + for (int i = 0; i < count; i++) { + final int index = docIds.apply(i) - offsetBits; + final int nextWordIndex = index >> 6; + assert currentWordIndex <= nextWordIndex; + if (currentWordIndex < nextWordIndex) { + out.writeLong(currentWord); + currentWord = 0L; + currentWordIndex++; + while (currentWordIndex < nextWordIndex) { + currentWordIndex++; + out.writeLong(0L); + } + } + currentWord |= 1L << index; + } + out.writeLong(currentWord); + assert currentWordIndex + 1 == totalWordCount; + } + + /** Read {@code count} integers into {@code docIDs}. */ + void readInts(IndexInput in, int count, int[] docIDs) throws IOException { + if (count > scratch.length) { + scratch = new int[count]; + } + final int bpv = in.readByte(); + switch (bpv) { + case CONTINUOUS_IDS: + readContinuousIds(in, count, docIDs); + break; + case BITSET_IDS: + readBitSet(in, count, docIDs); + break; + case DELTA_BPV_16: + readDelta16(in, count, docIDs); + break; + case BPV_21: + readInts21(in, count, docIDs); + break; + case BPV_24: + readInts24(in, count, docIDs); + break; + case BPV_32: + readInts32(in, count, docIDs); + break; + default: + throw new IOException("Unsupported number of bits per value: " + bpv); + } + } + + private DocIdSetIterator readBitSetIterator(IndexInput in, int count) throws IOException { + int offsetWords = in.readVInt(); + int longLen = in.readVInt(); + scratchLongs.longs = ArrayUtil.growNoCopy(scratchLongs.longs, longLen); + in.readLongs(scratchLongs.longs, 0, longLen); + // make ghost bits clear for FixedBitSet. 
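+        // "Ghost bits" are bits at indices >= numBits in the backing array. scratchLongs.longs is
+        // reused across calls, so words past longLen may still hold bits from a previous, larger
+        // bit set; the FixedBitSet(long[], int) constructor asserts that those bits are zero,
+        // which is why they are cleared here before the array is wrapped.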
+ if (longLen < scratchLongs.length) { + Arrays.fill(scratchLongs.longs, longLen, scratchLongs.longs.length, 0); + } + scratchLongs.length = longLen; + FixedBitSet bitSet = new FixedBitSet(scratchLongs.longs, longLen << 6); + return new DocBaseBitSetIterator(bitSet, count, offsetWords << 6); + } + + private static void readContinuousIds(IndexInput in, int count, int[] docIDs) throws IOException { + int start = in.readVInt(); + for (int i = 0; i < count; i++) { + docIDs[i] = start + i; + } + } + + private void readBitSet(IndexInput in, int count, int[] docIDs) throws IOException { + DocIdSetIterator iterator = readBitSetIterator(in, count); + int docId, pos = 0; + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + docIDs[pos++] = docId; + } + assert pos == count : "pos: " + pos + ", count: " + count; + } + + private static void readDelta16(IndexInput in, int count, int[] docIds) throws IOException { + final int min = in.readVInt(); + final int half = count >> 1; + in.readInts(docIds, 0, half); + if (count == DEFAULT_MAX_POINTS_IN_LEAF_NODE) { + // Same format, but enabling the JVM to specialize the decoding logic for the default number + // of points per node proved to help on benchmarks + decode16(docIds, DEFAULT_MAX_POINTS_IN_LEAF_NODE / 2, min); + } else { + decode16(docIds, half, min); + } + // read the remaining doc if count is odd. + for (int i = half << 1; i < count; i++) { + docIds[i] = Short.toUnsignedInt(in.readShort()) + min; + } + } + + private static void decode16(int[] docIDs, int half, int min) { + for (int i = 0; i < half; ++i) { + final int l = docIDs[i]; + docIDs[i] = (l >>> 16) + min; + docIDs[i + half] = (l & 0xFFFF) + min; + } + } + + private static int floorToMultipleOf16(int n) { + assert n >= 0; + return n & 0xFFFFFFF0; + } + + private void readInts21(IndexInput in, int count, int[] docIDs) throws IOException { + int oneThird = floorToMultipleOf16(count / 3); + int numInts = oneThird << 1; + in.readInts(scratch, 0, numInts); + if (count == DEFAULT_MAX_POINTS_IN_LEAF_NODE) { + // Same format, but enabling the JVM to specialize the decoding logic for the default number + // of points per node proved to help on benchmarks + decode21( + docIDs, + scratch, + floorToMultipleOf16(DEFAULT_MAX_POINTS_IN_LEAF_NODE / 3), + floorToMultipleOf16(DEFAULT_MAX_POINTS_IN_LEAF_NODE / 3) * 2 + ); + } else { + decode21(docIDs, scratch, oneThird, numInts); + } + int i = oneThird * 3; + for (; i < count - 2; i += 3) { + long l = in.readLong(); + docIDs[i] = (int) (l & 0x1FFFFFL); + docIDs[i + 1] = (int) ((l >>> 21) & 0x1FFFFFL); + docIDs[i + 2] = (int) (l >>> 42); + } + for (; i < count; ++i) { + docIDs[i] = (in.readShort() & 0xFFFF) | (in.readByte() & 0xFF) << 16; + } + } + + private static void decode21(int[] docIds, int[] scratch, int oneThird, int numInts) { + for (int i = 0; i < numInts; ++i) { + docIds[i] = scratch[i] >>> 11; + } + for (int i = 0; i < oneThird; i++) { + docIds[i + numInts] = (scratch[i] & 0x7FF) | ((scratch[i + oneThird] & 0x7FF) << 11); + } + } + + private void readInts24(IndexInput in, int count, int[] docIDs) throws IOException { + int quarter = count >> 2; + int numInts = quarter * 3; + in.readInts(scratch, 0, numInts); + if (count == DEFAULT_MAX_POINTS_IN_LEAF_NODE) { + // Same format, but enabling the JVM to specialize the decoding logic for the default number + // of points per node proved to help on benchmarks + assert floorToMultipleOf16(quarter) == quarter + : "We are relying on the fact that quarter of 
DEFAULT_MAX_POINTS_IN_LEAF_NODE" + + " is a multiple of 16 to vectorize the decoding loop," + + " please check performance issue if you want to break this assumption."; + decode24(docIDs, scratch, DEFAULT_MAX_POINTS_IN_LEAF_NODE / 4, DEFAULT_MAX_POINTS_IN_LEAF_NODE / 4 * 3); + } else { + decode24(docIDs, scratch, quarter, numInts); + } + // Now read the remaining 0, 1, 2 or 3 values + for (int i = quarter << 2; i < count; ++i) { + docIDs[i] = (in.readShort() & 0xFFFF) | (in.readByte() & 0xFF) << 16; + } + } + + private static void decode24(int[] docIDs, int[] scratch, int quarter, int numInts) { + for (int i = 0; i < numInts; ++i) { + docIDs[i] = scratch[i] >>> 8; + } + for (int i = 0; i < quarter; i++) { + docIDs[i + numInts] = (scratch[i] & 0xFF) | ((scratch[i + quarter] & 0xFF) << 8) | ((scratch[i + quarter * 2] & 0xFF) << 16); + } + } + + private static void readInts32(IndexInput in, int count, int[] docIDs) throws IOException { + in.readInts(docIDs, 0, count); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java new file mode 100644 index 000000000000..17270abdbf76 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/DocIdsWriterTests.java @@ -0,0 +1,154 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.CollectionUtil; +import org.apache.lucene.util.Constants; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Set; + +import static org.elasticsearch.index.codec.vectors.DocIdsWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE; + +public class DocIdsWriterTests extends LuceneTestCase { + + public void testRandom() throws Exception { + int numIters = atLeast(100); + try (Directory dir = newDirectory()) { + for (int iter = 0; iter < numIters; ++iter) { + int count = random().nextBoolean() ? 
1 + random().nextInt(5000) : DEFAULT_MAX_POINTS_IN_LEAF_NODE;
+                int[] docIDs = new int[count];
+                final int bpv = TestUtil.nextInt(random(), 1, 32);
+                for (int i = 0; i < docIDs.length; ++i) {
+                    docIDs[i] = TestUtil.nextInt(random(), 0, (1 << bpv) - 1);
+                }
+                test(dir, docIDs);
+            }
+        }
+    }
+
+    public void testSorted() throws Exception {
+        int numIters = atLeast(100);
+        try (Directory dir = newDirectory()) {
+            for (int iter = 0; iter < numIters; ++iter) {
+                int[] docIDs = new int[1 + random().nextInt(5000)];
+                final int bpv = TestUtil.nextInt(random(), 1, 32);
+                for (int i = 0; i < docIDs.length; ++i) {
+                    docIDs[i] = TestUtil.nextInt(random(), 0, (1 << bpv) - 1);
+                }
+                Arrays.sort(docIDs);
+                test(dir, docIDs);
+            }
+        }
+    }
+
+    public void testCluster() throws Exception {
+        int numIters = atLeast(100);
+        try (Directory dir = newDirectory()) {
+            for (int iter = 0; iter < numIters; ++iter) {
+                int count = random().nextBoolean() ? 1 + random().nextInt(5000) : DEFAULT_MAX_POINTS_IN_LEAF_NODE;
+                int[] docIDs = new int[count];
+                int min = random().nextInt(1000);
+                final int bpv = TestUtil.nextInt(random(), 1, 16);
+                for (int i = 0; i < docIDs.length; ++i) {
+                    docIDs[i] = min + TestUtil.nextInt(random(), 0, (1 << bpv) - 1);
+                }
+                test(dir, docIDs);
+            }
+        }
+    }
+
+    public void testBitSet() throws Exception {
+        int numIters = atLeast(100);
+        try (Directory dir = newDirectory()) {
+            for (int iter = 0; iter < numIters; ++iter) {
+                int size = 1 + random().nextInt(5000);
+                Set<Integer> set = CollectionUtil.newHashSet(size);
+                int small = random().nextInt(1000);
+                while (set.size() < size) {
+                    set.add(small + random().nextInt(size * 16));
+                }
+                int[] docIDs = set.stream().mapToInt(t -> t).sorted().toArray();
+                test(dir, docIDs);
+            }
+        }
+    }
+
+    public void testContinuousIds() throws Exception {
+        int numIters = atLeast(100);
+        try (Directory dir = newDirectory()) {
+            for (int iter = 0; iter < numIters; ++iter) {
+                int size = 1 + random().nextInt(5000);
+                int[] docIDs = new int[size];
+                int start = random().nextInt(1000000);
+                for (int i = 0; i < docIDs.length; i++) {
+                    docIDs[i] = start + i;
+                }
+                test(dir, docIDs);
+            }
+        }
+    }
+
+    private void test(Directory dir, int[] ints) throws Exception {
+        final long len;
+        // It is hard to get BPV24-encoded docs in TestLuceneXXPointsFormat, test bwc here as well.
+        DocIdsWriter docIdsWriter = new DocIdsWriter();
+        try (IndexOutput out = dir.createOutput("tmp", IOContext.DEFAULT)) {
+            docIdsWriter.writeDocIds(i -> ints[i], ints.length, out);
+            len = out.getFilePointer();
+            if (random().nextBoolean()) {
+                out.writeLong(0); // garbage
+            }
+        }
+        try (IndexInput in = dir.openInput("tmp", IOContext.READONCE)) {
+            int[] read = new int[ints.length];
+            docIdsWriter.readInts(in, ints.length, read);
+            assertArrayEquals(ints, read);
+            assertEquals(len, in.getFilePointer());
+        }
+        dir.deleteFile("tmp");
+    }
+
+    // This simple test tickles a JVM C2 JIT crash on JDKs earlier than 21.0.1.
+    // Crashes only when run with HotSpot C2.
+    // Regardless of whether C2 is enabled or not, the test should never fail.
+    public void testCrash() throws IOException {
+        assumeTrue("Requires HotSpot C2 compiler (won't work on client VM).", Constants.IS_HOTSPOT_VM && (Constants.IS_CLIENT_VM == false));
+        int itrs = atLeast(100);
+        for (int i = 0; i < itrs; i++) {
+            try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null))) {
+                for (int d = 0; d < 20_000; d++) {
+                    iw.addDocument(List.of(new IntPoint("foo", 0), new SortedNumericDocValuesField("bar", 0)));
+                }
+            }
+        }
+    }
+}
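
For reference, a minimal round-trip sketch of the API this patch adds, mirroring the test helper above. It is a sketch, not part of the patch: it assumes a caller in the same package (DocIdsWriter is package-private), and the ByteBuffersDirectory and file name are illustrative only.

import java.io.IOException;
import java.util.Arrays;

import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;

class DocIdsRoundTrip {
    public static void main(String[] args) throws IOException {
        int[] ids = { 10, 11, 12, 13, 14 }; // a consecutive run, so the CONTINUOUS_IDS tag is written
        DocIdsWriter writer = new DocIdsWriter();
        try (Directory dir = new ByteBuffersDirectory()) {
            try (IndexOutput out = dir.createOutput("ids", IOContext.DEFAULT)) {
                // ids are supplied through an accessor function, so any backing structure works
                writer.writeDocIds(i -> ids[i], ids.length, out);
            }
            int[] read = new int[ids.length];
            try (IndexInput in = dir.openInput("ids", IOContext.READONCE)) {
                writer.readInts(in, ids.length, read); // dispatches on the leading tag byte
            }
            if (Arrays.equals(ids, read) == false) {
                throw new AssertionError("round trip failed");
            }
        }
    }
}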