Optimize memory usage in ShardBulkInferenceActionFilter (#124313)

This refactor improves memory efficiency by processing inference requests in batches whose cumulative input size is capped in bytes.

Changes include:
- A new dynamic operator setting, `indices.inference.batch_size`, to control the maximum batch size in bytes.
- Dropping input data from inference responses when the legacy semantic text format isn’t used, saving memory.
- Clearing inference results eagerly after each bulk item is processed, to free memory sooner.

This is a step toward enabling circuit breakers to better handle memory usage when dealing with large inputs.
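To make the batching strategy concrete, here is a minimal, self-contained Java sketch of the accumulation loop described above. It is illustrative only: `ByteCappedBatcher`, `Item`, and `estimateInputSize` are hypothetical stand-ins, and the real filter groups `FieldInferenceRequest`s per inference ID inside `executeNext(int itemOffset)` rather than returning flat batches.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal sketch of byte-capped batching, assuming a hypothetical Item type and
// size estimate. Items are accumulated until the cumulative input size reaches
// the configured limit; the next call resumes from the returned offset.
public class ByteCappedBatcher {

    record Item(String input) {}

    private final long batchSizeInBytes;

    public ByteCappedBatcher(long batchSizeInBytes) {
        this.batchSizeInBytes = batchSizeInBytes;
    }

    /** Fills {@code batch} starting at {@code offset} and returns the offset of the first item not included. */
    public int nextBatch(List<Item> items, int offset, List<Item> batch) {
        long totalInputLength = 0;
        int itemIndex = offset;
        while (itemIndex < items.size() && totalInputLength < batchSizeInBytes) {
            Item item = items.get(itemIndex);
            totalInputLength += estimateInputSize(item);
            batch.add(item);
            itemIndex++;
        }
        return itemIndex;
    }

    // Stand-in for the filter's accounting, which sums the lengths of the text inputs per bulk item.
    private static long estimateInputSize(Item item) {
        return item.input().length();
    }

    public static void main(String[] args) {
        ByteCappedBatcher batcher = new ByteCappedBatcher(16); // tiny cap, for demonstration only
        List<Item> items = List.of(new Item("short"), new Item("a longer input"), new Item("more text"));
        int offset = 0;
        while (offset < items.size()) {
            List<Item> batch = new ArrayList<>();
            offset = batcher.nextBatch(items, offset, batch);
            System.out.println("dispatching a batch of " + batch.size() + " item(s)");
        }
    }
}
```

In the actual change this loop lives in the filter's `executeNext(int itemOffset)` method, and the cap comes from the `indices.inference.batch_size` setting shown in the diff below.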
Jim Ferenczi 2025-03-14 09:51:03 +00:00 committed by GitHub
parent 35ecbf6e87
commit 361b51d436
7 changed files with 316 additions and 248 deletions

Changelog entry (new file):

@@ -0,0 +1,5 @@
+ pr: 124313
+ summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+ area: Search
+ type: enhancement
+ issues: []

ShardBulkInferenceActionFilterIT:

@@ -20,6 +20,7 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
+ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.Set;
+ import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
    @Override
    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-       return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+       long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+       return Settings.builder()
+           .put(otherSettings)
+           .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+           .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+           .build();
    }
    @Override

InferencePlugin:

@@ -142,6 +142,7 @@ import java.util.function.Predicate;
import java.util.function.Supplier;
import static java.util.Collections.singletonList;
+ import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public class InferencePlugin extends Plugin
        settings.addAll(Truncator.getSettingsDefinitions());
        settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
        settings.add(SKIP_VALIDATE_AND_START);
+       settings.add(INDICES_INFERENCE_BATCH_SIZE);
        settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
        return settings;

ShardBulkInferenceActionFilter:

@@ -25,7 +25,11 @@ import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
+ import org.elasticsearch.common.bytes.BytesReference;
+ import org.elasticsearch.common.settings.Setting;
+ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.AtomicArray;
+ import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
@@ -43,6 +47,10 @@ import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
+ import org.elasticsearch.xcontent.XContentBuilder;
+ import org.elasticsearch.xcontent.XContentParser;
+ import org.elasticsearch.xcontent.XContentParserConfiguration;
+ import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.InferenceException;
@@ -63,6 +71,8 @@ import java.util.Map;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
+ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
+ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
/**
 * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -72,10 +82,23 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FE
 * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the
 * results during indexing on the shard.
 *
- * TODO: batchSize should be configurable via a cluster setting
 */
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
-   protected static final int DEFAULT_BATCH_SIZE = 512;
+   private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
+   /**
+    * Defines the cumulative size limit of input data before triggering a batch inference call.
+    * This setting controls how much data can be accumulated before an inference request is sent in batch.
+    */
+   public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
+       "indices.inference.batch_size",
+       DEFAULT_BATCH_SIZE,
+       ByteSizeValue.ONE,
+       ByteSizeValue.ofMb(100),
+       Setting.Property.NodeScope,
+       Setting.Property.OperatorDynamic
+   );
    private static final Object EXPLICIT_NULL = new Object();
    private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();
@@ -83,29 +106,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
    private final InferenceServiceRegistry inferenceServiceRegistry;
    private final ModelRegistry modelRegistry;
    private final XPackLicenseState licenseState;
-   private final int batchSize;
+   private volatile long batchSizeInBytes;
    public ShardBulkInferenceActionFilter(
        ClusterService clusterService,
        InferenceServiceRegistry inferenceServiceRegistry,
        ModelRegistry modelRegistry,
        XPackLicenseState licenseState
-   ) {
-       this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
-   }
-   public ShardBulkInferenceActionFilter(
-       ClusterService clusterService,
-       InferenceServiceRegistry inferenceServiceRegistry,
-       ModelRegistry modelRegistry,
-       XPackLicenseState licenseState,
-       int batchSize
    ) {
        this.clusterService = clusterService;
        this.inferenceServiceRegistry = inferenceServiceRegistry;
        this.modelRegistry = modelRegistry;
        this.licenseState = licenseState;
-       this.batchSize = batchSize;
+       this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
+       clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
+   }
+   private void setBatchSize(ByteSizeValue newBatchSize) {
+       batchSizeInBytes = newBatchSize.getBytes();
    }
    @Override
@@ -148,14 +166,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
    /**
     * A field inference request on a single input.
-    * @param index The index of the request in the original bulk request.
+    * @param bulkItemIndex The index of the item in the original bulk request.
     * @param field The target field.
     * @param sourceField The source field.
     * @param input The input to run inference on.
     * @param inputOrder The original order of the input.
     * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
     */
-   private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {}
+   private record FieldInferenceRequest(
+       int bulkItemIndex,
+       String field,
+       String sourceField,
+       String input,
+       int inputOrder,
+       int offsetAdjustment
+   ) {}
    /**
     * The field inference response.
@@ -218,29 +243,54 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
        @Override
        public void run() {
-           Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
+           executeNext(0);
+       }
+       private void executeNext(int itemOffset) {
+           if (itemOffset >= bulkShardRequest.items().length) {
+               onCompletion.run();
+               return;
+           }
+           var items = bulkShardRequest.items();
+           Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
+           long totalInputLength = 0;
+           int itemIndex = itemOffset;
+           while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
+               var item = items[itemIndex];
+               totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
+               itemIndex += 1;
+           }
+           int nextItemOffset = itemIndex;
            Runnable onInferenceCompletion = () -> {
                try {
-                   for (var inferenceResponse : inferenceResults.asList()) {
-                       var request = bulkShardRequest.items()[inferenceResponse.id];
-                       try {
-                           applyInferenceResponses(request, inferenceResponse);
-                       } catch (Exception exc) {
-                           request.abort(bulkShardRequest.index(), exc);
-                       }
+                   for (int i = itemOffset; i < nextItemOffset; i++) {
+                       var result = inferenceResults.get(i);
+                       if (result == null) {
+                           continue;
+                       }
+                       var item = items[i];
+                       try {
+                           applyInferenceResponses(item, result);
+                       } catch (Exception exc) {
+                           item.abort(bulkShardRequest.index(), exc);
+                       }
+                       // we don't need to keep the inference results around
+                       inferenceResults.set(i, null);
                    }
                } finally {
-                   onCompletion.run();
+                   executeNext(nextItemOffset);
                }
            };
            try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
-               for (var entry : inferenceRequests.entrySet()) {
-                   executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
+               for (var entry : fieldRequestsMap.entrySet()) {
+                   executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
                }
            }
        }
-   private void executeShardBulkInferenceAsync(
+   private void executeChunkedInferenceAsync(
        final String inferenceId,
        @Nullable InferenceProvider inferenceProvider,
        final List<FieldInferenceRequest> requests,
@@ -262,11 +312,11 @@
                            unparsedModel.secrets()
                        )
                    );
-                   executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
+                   executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
                } else {
                    try (onFinish) {
                        for (FieldInferenceRequest request : requests) {
-                           inferenceResults.get(request.index).failures.add(
+                           inferenceResults.get(request.bulkItemIndex).failures.add(
                                new ResourceNotFoundException(
                                    "Inference service [{}] not found for field [{}]",
                                    unparsedModel.service(),
@@ -297,7 +347,7 @@
                                    request.field
                                )
                            );
                        }
-                       inferenceResults.get(request.index).failures.add(failure);
+                       inferenceResults.get(request.bulkItemIndex).failures.add(failure);
                    }
                }
            }
@@ -305,18 +355,15 @@
            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
            return;
        }
-       int currentBatchSize = Math.min(requests.size(), batchSize);
-       final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
-       final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
-       final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+       final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
        ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
            @Override
            public void onResponse(List<ChunkedInference> results) {
-               try {
+               try (onFinish) {
                    var requestsIterator = requests.iterator();
                    for (ChunkedInference result : results) {
                        var request = requestsIterator.next();
-                       var acc = inferenceResults.get(request.index);
+                       var acc = inferenceResults.get(request.bulkItemIndex);
                        if (result instanceof ChunkedInferenceError error) {
                            acc.addFailure(
                                new InferenceException(
@@ -331,7 +378,7 @@
                                new FieldInferenceResponse(
                                    request.field(),
                                    request.sourceField(),
-                                   request.input(),
+                                   useLegacyFormat ? request.input() : null,
                                    request.inputOrder(),
                                    request.offsetAdjustment(),
                                    inferenceProvider.model,
@@ -340,17 +387,15 @@
                            );
                        }
                    }
-               } finally {
-                   onFinish();
                }
            }
            @Override
            public void onFailure(Exception exc) {
-               try {
+               try (onFinish) {
                    for (FieldInferenceRequest request : requests) {
                        addInferenceResponseFailure(
-                           request.index,
+                           request.bulkItemIndex,
                            new InferenceException(
                                "Exception when running inference id [{}] on field [{}]",
                                exc,
@@ -359,16 +404,6 @@
                            )
                        );
                    }
-               } finally {
-                   onFinish();
-               }
-           }
-           private void onFinish() {
-               if (nextBatch.isEmpty()) {
-                   onFinish.close();
-               } else {
-                   executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish);
                }
            }
        };
@@ -376,115 +411,17 @@
            .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
    }
-   private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
-       FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
-       if (acc == null) {
-           acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
-           inferenceResults.set(id, acc);
-       }
-       return acc;
-   }
-   private void addInferenceResponseFailure(int id, Exception failure) {
-       var acc = ensureResponseAccumulatorSlot(id);
-       acc.addFailure(failure);
-   }
    /**
-    * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
-    * If the response contains failures, the bulk item request is marked as failed for the downstream action.
-    * Otherwise, the source of the request is augmented with the field inference results.
+    * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
+    * for the specified {@code item}.
+    *
+    * @param item The bulk request item to process.
+    * @param itemIndex The position of the item within the original bulk request.
+    * @param requestsMap A map storing inference requests, where each key is an inference ID,
+    *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
+    * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
     */
-   private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException {
-       if (response.failures().isEmpty() == false) {
-           for (var failure : response.failures()) {
-               item.abort(item.index(), failure);
-           }
-           return;
-       }
-       final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
-       var newDocMap = indexRequest.sourceAsMap();
-       Map<String, Object> inferenceFieldsMap = new HashMap<>();
-       for (var entry : response.responses.entrySet()) {
-           var fieldName = entry.getKey();
-           var responses = entry.getValue();
-           Model model = null;
-           InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
-           if (inferenceFieldMetadata == null) {
-               throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
-           }
-           // ensure that the order in the original field is consistent in case of multiple inputs
-           Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
-           Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
-           for (var resp : responses) {
-               // Get the first non-null model from the response list
-               if (model == null) {
-                   model = resp.model;
-               }
-               var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
-               lst.addAll(
-                   SemanticTextField.toSemanticTextFieldChunks(
-                       resp.input,
-                       resp.offsetAdjustment,
-                       resp.chunkedResults,
-                       indexRequest.getContentType(),
-                       useLegacyFormat
-                   )
-               );
-           }
-           List<String> inputs = responses.stream()
-               .filter(r -> r.sourceField().equals(fieldName))
-               .map(r -> r.input)
-               .collect(Collectors.toList());
-           // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
-           // merge in the field's existing model settings on the data node.
-           var result = new SemanticTextField(
-               useLegacyFormat,
-               fieldName,
-               useLegacyFormat ? inputs : null,
-               new SemanticTextField.InferenceResult(
-                   inferenceFieldMetadata.getInferenceId(),
-                   model != null ? new MinimalServiceSettings(model) : null,
-                   chunkMap
-               ),
-               indexRequest.getContentType()
-           );
-           if (useLegacyFormat) {
-               SemanticTextUtils.insertValue(fieldName, newDocMap, result);
-           } else {
-               inferenceFieldsMap.put(fieldName, result);
-           }
-       }
-       if (useLegacyFormat == false) {
-           newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);
-       }
-       indexRequest.source(newDocMap, indexRequest.getContentType());
-   }
-   /**
-    * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index.
-    * If results are already populated for fields in the original index request, the inference request for this specific
-    * field is skipped, and the existing results remain unchanged.
-    * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing,
-    * where an error will be thrown if they mismatch or if the content is malformed.
-    * <p>
-    * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ?
-    */
-   private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
-       Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
-       for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) {
-           var item = bulkShardRequest.items()[itemIndex];
-           if (item.getPrimaryResponse() != null) {
-               // item was already aborted/processed by a filter in the chain upstream (e.g. security)
-               continue;
-           }
+   private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
        boolean isUpdateRequest = false;
        final IndexRequest indexRequest;
        if (item.request() instanceof IndexRequest ir) {
@@ -500,15 +437,16 @@
                        SemanticTextFieldMapper.CONTENT_TYPE
                    )
                );
-               continue;
+               return 0;
            }
            indexRequest = updateRequest.doc();
        } else {
            // ignore delete request
-           continue;
+           return 0;
        }
        final Map<String, Object> docMap = indexRequest.sourceAsMap();
+       long inputLength = 0;
        for (var entry : fieldInferenceMap.values()) {
            String field = entry.getName();
            String inferenceId = entry.getInferenceId();
@@ -577,15 +515,16 @@
                    break;
                }
-               List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+               List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                int offsetAdjustment = 0;
                for (String v : values) {
+                   inputLength += v.length();
                    if (v.isBlank()) {
                        slot.addOrUpdateResponse(
                            new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
                        );
                    } else {
-                       fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
+                       requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
                    }
                    // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -595,9 +534,127 @@
                    }
                }
            }
-           return fieldRequestsMap;
+           return inputLength;
        }
+   private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
+       FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
+       if (acc == null) {
+           acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
+           inferenceResults.set(id, acc);
+       }
+       return acc;
+   }
+   private void addInferenceResponseFailure(int id, Exception failure) {
+       var acc = ensureResponseAccumulatorSlot(id);
+       acc.addFailure(failure);
+   }
+   /**
+    * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
+    * If the response contains failures, the bulk item request is marked as failed for the downstream action.
+    * Otherwise, the source of the request is augmented with the field inference results.
+    */
+   private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException {
+       if (response.failures().isEmpty() == false) {
+           for (var failure : response.failures()) {
+               item.abort(item.index(), failure);
+           }
+           return;
+       }
+       final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
+       Map<String, Object> inferenceFieldsMap = new HashMap<>();
+       for (var entry : response.responses.entrySet()) {
+           var fieldName = entry.getKey();
+           var responses = entry.getValue();
+           Model model = null;
+           InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
+           if (inferenceFieldMetadata == null) {
+               throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
+           }
+           // ensure that the order in the original field is consistent in case of multiple inputs
+           Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
+           Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
+           for (var resp : responses) {
+               // Get the first non-null model from the response list
+               if (model == null) {
+                   model = resp.model;
+               }
+               var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
+               var chunks = useLegacyFormat
+                   ? toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType())
+                   : toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType());
+               lst.addAll(chunks);
+           }
+           List<String> inputs = useLegacyFormat
+               ? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList())
+               : null;
+           // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
+           // merge in the field's existing model settings on the data node.
+           var result = new SemanticTextField(
+               useLegacyFormat,
+               fieldName,
+               inputs,
+               new SemanticTextField.InferenceResult(
+                   inferenceFieldMetadata.getInferenceId(),
+                   model != null ? new MinimalServiceSettings(model) : null,
+                   chunkMap
+               ),
+               indexRequest.getContentType()
+           );
+           inferenceFieldsMap.put(fieldName, result);
+       }
+       if (useLegacyFormat) {
+           var newDocMap = indexRequest.sourceAsMap();
+           for (var entry : inferenceFieldsMap.entrySet()) {
+               SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue());
+           }
+           indexRequest.source(newDocMap, indexRequest.getContentType());
+       } else {
+           try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
+               appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap);
+               indexRequest.source(builder);
+           }
+       }
+   }
+   /**
+    * Appends the original source and the new inference metadata field directly to the provided
+    * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
+    */
+   private static void appendSourceAndInferenceMetadata(
+       XContentBuilder builder,
+       BytesReference source,
+       XContentType xContentType,
+       Map<String, Object> inferenceFieldsMap
+   ) throws IOException {
+       builder.startObject();
+       // append the original source
+       try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) {
+           // skip start object
+           parser.nextToken();
+           while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+               builder.copyCurrentStructure(parser);
+           }
+       }
+       // add the inference metadata field
+       builder.field(InferenceMetadataFieldsMapper.NAME);
+       try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
+           builder.copyCurrentStructure(parser);
+       }
+       builder.endObject();
+   }
    static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {

SemanticTextField:

@@ -267,37 +267,38 @@ public record SemanticTextField(
    /**
     * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
     */
-   public static List<Chunk> toSemanticTextFieldChunks(
-       String input,
-       int offsetAdjustment,
-       ChunkedInference results,
-       XContentType contentType,
-       boolean useLegacyFormat
-   ) throws IOException {
+   public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
+       throws IOException {
        List<Chunk> chunks = new ArrayList<>();
        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
        while (it.hasNext()) {
-           chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
+           chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
        }
        return chunks;
    }
-   public static Chunk toSemanticTextFieldChunk(
-       String input,
-       int offsetAdjustment,
-       ChunkedInference.Chunk chunk,
-       boolean useLegacyFormat
-   ) {
+   /**
+    * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
+    */
+   public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
        String text = null;
-       int startOffset = -1;
-       int endOffset = -1;
-       if (useLegacyFormat) {
-           text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
-       } else {
-           startOffset = chunk.textOffset().start() + offsetAdjustment;
-           endOffset = chunk.textOffset().end() + offsetAdjustment;
-       }
+       int startOffset = chunk.textOffset().start() + offsetAdjustment;
+       int endOffset = chunk.textOffset().end() + offsetAdjustment;
        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
    }
+   public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
+       throws IOException {
+       List<Chunk> chunks = new ArrayList<>();
+       Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
+       while (it.hasNext()) {
+           chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
+       }
+       return chunks;
+   }
+   public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+       var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+       return new Chunk(text, -1, -1, chunk.bytesReference());
+   }
}

ShardBulkInferenceActionFilterTests:

@@ -28,7 +28,9 @@ import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
+ import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
+ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.IndexVersion;
@@ -66,12 +68,13 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+ import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
- import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
+ import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -118,7 +121,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public void testFilterNoop() throws Exception {
-       ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
+       ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
        CountDownLatch chainExecuted = new CountDownLatch(1);
        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
            try {
@@ -144,7 +147,7 @@
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public void testLicenseInvalidForInference() throws InterruptedException {
        StaticModel model = StaticModel.createRandomInstance();
-       ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
+       ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
        CountDownLatch chainExecuted = new CountDownLatch(1);
        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
            try {
@@ -185,7 +188,6 @@
        ShardBulkInferenceActionFilter filter = createFilter(
            threadPool,
            Map.of(model.getInferenceEntityId(), model),
-           randomIntBetween(1, 10),
            useLegacyFormat,
            true
        );
@@ -232,7 +234,6 @@
        ShardBulkInferenceActionFilter filter = createFilter(
            threadPool,
            Map.of(model.getInferenceEntityId(), model),
-           randomIntBetween(1, 10),
            useLegacyFormat,
            true
        );
@@ -303,7 +304,6 @@
        ShardBulkInferenceActionFilter filter = createFilter(
            threadPool,
            Map.of(model.getInferenceEntityId(), model),
-           randomIntBetween(1, 10),
            useLegacyFormat,
            true
        );
@@ -374,7 +374,6 @@
        ShardBulkInferenceActionFilter filter = createFilter(
            threadPool,
            Map.of(model.getInferenceEntityId(), model),
-           randomIntBetween(1, 10),
            useLegacyFormat,
            true
        );
@@ -447,13 +446,7 @@
            modifiedRequests[id] = res[1];
        }
-       ShardBulkInferenceActionFilter filter = createFilter(
-           threadPool,
-           inferenceModelMap,
-           randomIntBetween(10, 30),
-           useLegacyFormat,
-           true
-       );
+       ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
        CountDownLatch chainExecuted = new CountDownLatch(1);
        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
            try {
@@ -487,7 +480,6 @@
    private static ShardBulkInferenceActionFilter createFilter(
        ThreadPool threadPool,
        Map<String, StaticModel> modelMap,
-       int batchSize,
        boolean useLegacyFormat,
        boolean isLicenseValidForInference
    ) {
@@ -554,18 +546,17 @@
            createClusterService(useLegacyFormat),
            inferenceServiceRegistry,
            modelRegistry,
-           licenseState,
-           batchSize
+           licenseState
        );
    }
    private static ClusterService createClusterService(boolean useLegacyFormat) {
        IndexMetadata indexMetadata = mock(IndexMetadata.class);
-       var settings = Settings.builder()
+       var indexSettings = Settings.builder()
            .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
            .build();
-       when(indexMetadata.getSettings()).thenReturn(settings);
+       when(indexMetadata.getSettings()).thenReturn(indexSettings);
        ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build());
        when(project.index(anyString())).thenReturn(indexMetadata);
@@ -576,7 +567,10 @@
        ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
        ClusterService clusterService = mock(ClusterService.class);
        when(clusterService.state()).thenReturn(clusterState);
+       long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+       Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
+       when(clusterService.getSettings()).thenReturn(settings);
+       when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
        return clusterService;
    }
@@ -587,7 +581,8 @@
    ) throws IOException {
        Map<String, Object> docMap = new LinkedHashMap<>();
        Map<String, Object> expectedDocMap = new LinkedHashMap<>();
-       XContentType requestContentType = randomFrom(XContentType.values());
+       // force JSON to avoid double/float conversions
+       XContentType requestContentType = XContentType.JSON;
        Map<String, Object> inferenceMetadataFields = new HashMap<>();
        for (var entry : fieldInferenceMap.values()) {

SemanticTextFieldTests:

@@ -41,6 +41,7 @@ import java.util.function.Predicate;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
+ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -274,7 +275,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
        while (inputsIt.hasNext() && chunkIt.hasNext()) {
            String input = inputsIt.next();
            var chunk = chunkIt.next();
-           chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
+           chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));
            // When using the inference metadata fields format, all the input values are concatenated so that the
            // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment