Improve accuracy of write load forecast when shard numbers change (#129990)

This commit is contained in:
Nick Tindall 2025-06-27 13:04:50 +10:00 committed by GitHub
parent 8acf94b6c3
commit 77b459c454
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 215 additions and 36 deletions

View file

@ -0,0 +1,5 @@
pr: 129990
summary: Make forecast write load accurate when shard numbers change
area: Allocation
type: bug
issues: []

View file

@ -362,7 +362,7 @@ public class DataStreamAutoShardingService {
* <p>If the recommendation is to INCREASE/DECREASE shards the reported cooldown period will be TimeValue.ZERO.
* If the auto sharding service thinks the number of shards must be changed but it can't recommend a change due to the cooldown
* period not lapsing, the result will be of type {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} or
* {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} with the remaining cooldown configured and the number of shards that should
* {@link AutoShardingType#COOLDOWN_PREVENTED_DECREASE} with the remaining cooldown configured and the number of shards that should
* be configured for the data stream once the remaining cooldown lapses as the target number of shards.
*
* <p>The NOT_APPLICABLE type result will report a cooldown period of TimeValue.MAX_VALUE.

View file

@ -108,7 +108,12 @@ class LicensedWriteLoadForecaster implements WriteLoadForecaster {
}
final IndexMetadata writeIndex = metadata.getSafe(dataStream.getWriteIndex());
metadata.put(IndexMetadata.builder(writeIndex).indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble()).build(), false);
metadata.put(
IndexMetadata.builder(writeIndex)
.indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble() / writeIndex.getNumberOfShards())
.build(),
false
);
return metadata;
}
@ -129,11 +134,20 @@ class LicensedWriteLoadForecaster implements WriteLoadForecaster {
}
}
/**
* This calculates the weighted average total write-load for all recent indices.
*
* @param indicesWriteLoadWithinMaxAgeRange The indices considered "recent"
* @return The weighted average total write-load. To get the per-shard write load, this number must be divided by the number of shards
*/
// Visible for testing
static OptionalDouble forecastIndexWriteLoad(List<IndexWriteLoad> indicesWriteLoadWithinMaxAgeRange) {
double totalWeightedWriteLoad = 0;
long totalShardUptime = 0;
double allIndicesWriteLoad = 0;
long allIndicesUptime = 0;
for (IndexWriteLoad writeLoad : indicesWriteLoadWithinMaxAgeRange) {
double totalShardWriteLoad = 0;
long totalShardUptimeInMillis = 0;
long maxShardUptimeInMillis = 0;
for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
@ -141,13 +155,27 @@ class LicensedWriteLoadForecaster implements WriteLoadForecaster {
assert uptimeInMillisForShard.isPresent();
double shardWriteLoad = writeLoadForShard.getAsDouble();
long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
totalWeightedWriteLoad += shardWriteLoad * shardUptimeInMillis;
totalShardUptime += shardUptimeInMillis;
totalShardWriteLoad += shardWriteLoad * shardUptimeInMillis;
totalShardUptimeInMillis += shardUptimeInMillis;
maxShardUptimeInMillis = Math.max(maxShardUptimeInMillis, shardUptimeInMillis);
}
}
double weightedAverageShardWriteLoad = totalShardWriteLoad / totalShardUptimeInMillis;
double totalIndexWriteLoad = weightedAverageShardWriteLoad * writeLoad.numberOfShards();
// We need to weight the contribution from each index somehow, but we only know
// the write-load from the final allocation of each shard at rollover time. It's
// possible the index is much older than any of those shards, but we don't have
// any write-load data beyond their lifetime.
// To avoid making assumptions about periods for which we have no data, we'll weight
// each index's contribution to the forecast by the maximum shard uptime observed in
// that index. It should be safe to extrapolate our weighted average out to the
// maximum uptime observed, based on the assumption that write-load is roughly
// evenly distributed across shards of a datastream index.
allIndicesWriteLoad += totalIndexWriteLoad * maxShardUptimeInMillis;
allIndicesUptime += maxShardUptimeInMillis;
}
return totalShardUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(totalWeightedWriteLoad / totalShardUptime);
return allIndicesUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(allIndicesWriteLoad / allIndicesUptime);
}
@Override

View file

@ -9,6 +9,7 @@ package org.elasticsearch.xpack.writeloadforecaster;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.LogEvent;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexMetadataStats;
@ -24,16 +25,19 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matcher;
import org.junit.After;
import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.writeloadforecaster.LicensedWriteLoadForecaster.forecastIndexWriteLoad;
import static org.hamcrest.Matchers.closeTo;
@ -42,6 +46,7 @@ import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
public class LicensedWriteLoadForecasterTests extends ESTestCase {
ThreadPool threadPool;
@ -67,33 +72,15 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
writeLoadForecaster.refreshLicense();
final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
final String dataStreamName = "logs-es";
final int numberOfBackingIndices = 10;
final int numberOfShards = randomIntBetween(1, 5);
final List<Index> backingIndices = new ArrayList<>();
for (int i = 0; i < numberOfBackingIndices; i++) {
final IndexMetadata indexMetadata = createIndexMetadata(
DataStream.getDefaultBackingIndexName(dataStreamName, i),
numberOfShards,
randomIndexWriteLoad(numberOfShards),
System.currentTimeMillis() - (maxIndexAge.millis() / 2)
);
backingIndices.add(indexMetadata.getIndex());
metadataBuilder.put(indexMetadata, false);
}
final IndexMetadata writeIndexMetadata = createIndexMetadata(
DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
numberOfShards,
null,
System.currentTimeMillis()
final ProjectMetadata.Builder metadataBuilder = createMetadataBuilderWithDataStream(
dataStreamName,
numberOfBackingIndices,
randomIntBetween(1, 5),
maxIndexAge
);
backingIndices.add(writeIndexMetadata.getIndex());
metadataBuilder.put(writeIndexMetadata, false);
final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
metadataBuilder.put(dataStream);
final DataStream dataStream = metadataBuilder.dataStream(dataStreamName);
final ProjectMetadata.Builder updatedMetadataBuilder = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
dataStream.getName(),
@ -253,7 +240,7 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(14.4)));
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(72.0)));
}
{
@ -264,14 +251,14 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
.withShardWriteLoad(1, 24, 999, 999, 5)
.withShardWriteLoad(2, 24, 999, 999, 5)
.withShardWriteLoad(3, 24, 999, 999, 5)
.withShardWriteLoad(4, 24, 999, 999, 4)
.withShardWriteLoad(4, 24, 999, 999, 5)
.build(),
// Since this shard uptime is really low, it doesn't add much to the avg
IndexWriteLoad.builder(1).withShardWriteLoad(0, 120, 999, 999, 1).build()
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(15.36)));
assertThat(writeLoadForecast.getAsDouble(), is(closeTo(72.59, 0.01)));
}
{
@ -283,7 +270,7 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(12.0)));
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(16.0)));
}
{
@ -302,7 +289,7 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
assertThat(writeLoadForecast.getAsDouble(), is(closeTo(15.83, 0.01)));
assertThat(writeLoadForecast.getAsDouble(), is(closeTo(31.66, 0.01)));
}
}
@ -404,4 +391,163 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
);
}, LicensedWriteLoadForecaster.class, collectingLoggingAssertion);
}
/**
 * Increasing the write-index shard count must leave the forecast total write-load
 * (per-shard forecast multiplied by shard count) unchanged.
 */
public void testShardIncreaseDoesNotIncreaseTotalLoad() {
testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.INCREASE);
}
/**
 * Decreasing the write-index shard count must leave the forecast total write-load
 * (per-shard forecast multiplied by shard count) unchanged.
 */
public void testShardDecreaseDoesNotDecreaseTotalLoad() {
testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.DECREASE);
}
/**
 * Forecasts the write load for a data stream, then re-creates the stream with a
 * different write-index shard count and forecasts again. The shard count must move
 * in the direction described by {@code shardCountChange}, while the total forecast
 * write-load (shards * per-shard forecast) must stay the same.
 */
private void testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange shardCountChange) {
    final TimeValue maxIndexAge = TimeValue.timeValueDays(7);
    final AtomicBoolean hasValidLicense = new AtomicBoolean(true);
    final AtomicInteger licenseCheckCount = new AtomicInteger();
    final WriteLoadForecaster forecaster = new LicensedWriteLoadForecaster(() -> {
        licenseCheckCount.incrementAndGet();
        return hasValidLicense.get();
    }, threadPool, maxIndexAge);
    forecaster.refreshLicense();

    // Build a data stream and compute the forecast for its current write index.
    final String dataStreamName = randomIdentifier();
    final ProjectMetadata.Builder beforeChange = forecaster.withWriteLoadForecastForWriteIndex(
        dataStreamName,
        createMetadataBuilderWithDataStream(dataStreamName, randomIntBetween(5, 15), shardCountChange.originalShardCount(), maxIndexAge)
    );

    // Generate the same data stream, but with a different number of shards in the write index
    final ProjectMetadata.Builder afterChange = forecaster.withWriteLoadForecastForWriteIndex(
        dataStreamName,
        updateWriteIndexShardCount(dataStreamName, beforeChange, shardCountChange)
    );

    final IndexMetadata writeIndexBefore = beforeChange.getSafe(beforeChange.dataStream(dataStreamName).getWriteIndex());
    final IndexMetadata writeIndexAfter = afterChange.getSafe(afterChange.dataStream(dataStreamName).getWriteIndex());

    // The shard count changed in the expected direction...
    assertThat(writeIndexAfter.getNumberOfShards(), shardCountChange.expectedChangeFromOriginal(writeIndexBefore.getNumberOfShards()));

    // ...but the total write-load did not.
    final double totalLoadBefore = writeIndexBefore.getNumberOfShards() * forecaster.getForecastedWriteLoad(writeIndexBefore)
        .getAsDouble();
    final double totalLoadAfter = writeIndexAfter.getNumberOfShards() * forecaster.getForecastedWriteLoad(writeIndexAfter)
        .getAsDouble();
    assertThat(totalLoadAfter, closeTo(totalLoadBefore, 0.01));
}
/**
 * Describes a change to the write-index shard count: each constant supplies a random
 * starting shard count range and, via {@link IntToIntFunction#apply}, a strictly
 * increased or strictly decreased replacement count.
 */
public enum ShardCountChange implements IntToIntFunction {
    INCREASE(1, 15) {
        @Override
        public int apply(int originalShardCount) {
            // Always strictly more shards than before
            return randomIntBetween(originalShardCount + 1, originalShardCount * 3);
        }

        @Override
        public Matcher<Integer> expectedChangeFromOriginal(int originalShardCount) {
            return greaterThan(originalShardCount);
        }
    },
    DECREASE(10, 30) {
        @Override
        public int apply(int originalShardCount) {
            // Always strictly fewer shards than before (the minimum original count is 10)
            return randomIntBetween(1, originalShardCount - 1);
        }

        @Override
        public Matcher<Integer> expectedChangeFromOriginal(int originalShardCount) {
            return lessThan(originalShardCount);
        }
    };

    private final int originalMinimumShardCount;
    private final int originalMaximumShardCount;

    ShardCountChange(int originalMinimumShardCount, int originalMaximumShardCount) {
        this.originalMinimumShardCount = originalMinimumShardCount;
        this.originalMaximumShardCount = originalMaximumShardCount;
    }

    /**
     * @return a random shard count within this change's configured starting range
     */
    public int originalShardCount() {
        return randomIntBetween(originalMinimumShardCount, originalMaximumShardCount);
    }

    /**
     * @return a matcher asserting how the changed shard count relates to {@code originalShardCount}
     */
    abstract Matcher<Integer> expectedChangeFromOriginal(int originalShardCount);
}
/**
 * Builds a copy of {@code originalMetadata}'s data stream in which the write index is
 * replaced by a fresh one whose shard count has been changed per {@code shardCountChange}.
 * All non-write backing indices are copied over unchanged.
 */
private ProjectMetadata.Builder updateWriteIndexShardCount(
    String dataStreamName,
    ProjectMetadata.Builder originalMetadata,
    ShardCountChange shardCountChange
) {
    final ProjectMetadata.Builder updatedShardCountMetadata = ProjectMetadata.builder(originalMetadata.getId());
    final DataStream originalDataStream = originalMetadata.dataStream(dataStreamName);
    final Index existingWriteIndex = Objects.requireNonNull(originalDataStream.getWriteIndex());
    final IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(existingWriteIndex);
    // Copy all non-write indices over unchanged. Filter by equals rather than reference
    // identity, and collect into an explicit ArrayList because Collectors.toList() makes
    // no mutability guarantee and we append the new write index below.
    final List<IndexMetadata> backingIndexMetadatas = originalDataStream.getIndices()
        .stream()
        .filter(index -> index.equals(existingWriteIndex) == false)
        .map(originalMetadata::getSafe)
        .collect(Collectors.toCollection(ArrayList::new));
    // Create a new write index with an updated shard count
    final IndexMetadata writeIndexMetadata = createIndexMetadata(
        DataStream.getDefaultBackingIndexName(dataStreamName, backingIndexMetadatas.size()),
        shardCountChange.apply(originalWriteIndexMetadata.getNumberOfShards()),
        null,
        System.currentTimeMillis()
    );
    backingIndexMetadatas.add(writeIndexMetadata);
    backingIndexMetadatas.forEach(indexMetadata -> updatedShardCountMetadata.put(indexMetadata, false));
    final DataStream dataStream = createDataStream(
        dataStreamName,
        backingIndexMetadatas.stream().map(IndexMetadata::getIndex).toList()
    );
    updatedShardCountMetadata.put(dataStream);
    return updatedShardCountMetadata;
}
/**
 * Builds project metadata containing a data stream with {@code numberOfBackingIndices}
 * rolled-over backing indices (each carrying random write-load stats and created halfway
 * through the max-age window) plus a current write index with no write-load stats.
 */
private ProjectMetadata.Builder createMetadataBuilderWithDataStream(
    String dataStreamName,
    int numberOfBackingIndices,
    int numberOfShards,
    TimeValue maxIndexAge
) {
    final ProjectMetadata.Builder builder = ProjectMetadata.builder(randomProjectIdOrDefault());
    final List<Index> indices = new ArrayList<>(numberOfBackingIndices + 1);
    // Historical backing indices: recent enough (half the max age) to count toward the forecast
    for (int ordinal = 0; ordinal < numberOfBackingIndices; ordinal++) {
        final IndexMetadata backingIndex = createIndexMetadata(
            DataStream.getDefaultBackingIndexName(dataStreamName, ordinal),
            numberOfShards,
            randomIndexWriteLoad(numberOfShards),
            System.currentTimeMillis() - (maxIndexAge.millis() / 2)
        );
        indices.add(backingIndex.getIndex());
        builder.put(backingIndex, false);
    }
    // The current write index carries no write-load stats yet
    final IndexMetadata writeIndex = createIndexMetadata(
        DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
        numberOfShards,
        null,
        System.currentTimeMillis()
    );
    indices.add(writeIndex.getIndex());
    builder.put(writeIndex, false);
    builder.put(createDataStream(dataStreamName, indices));
    return builder;
}
}