Optimize memory usage in ShardBulkInferenceActionFilter (#124313)

This refactor improves memory efficiency by processing inference requests in batches, capping each batch by the cumulative size of its input data.

Changes include:
- A new dynamic operator setting, `indices.inference.batch_size`, to control the maximum batch size in bytes.
- Dropping input data from inference responses when the legacy semantic text format isn’t used, saving memory.
- Clearing each bulk item's inference results as soon as they are applied, to free memory sooner.

This is a step toward enabling circuit breakers to better handle memory usage when dealing with large inputs.
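As an illustration of the batching strategy, here is a minimal, self-contained Java sketch with hypothetical names; it is not the commit's code (the real logic lives in ShardBulkInferenceActionFilter#executeNext, shown in the diff below):

import java.util.ArrayList;
import java.util.List;

// Minimal sketch: greedily group inputs into batches until a size cap is reached.
final class ByteCappedBatcher {
    private final long maxBatchSizeInBytes;

    ByteCappedBatcher(long maxBatchSizeInBytes) {
        this.maxBatchSizeInBytes = maxBatchSizeInBytes;
    }

    // Like the commit's executeNext loop, the cap is checked after each addition,
    // so a batch may slightly overshoot it, and string length is used as a cheap
    // proxy for input size.
    List<List<String>> batches(List<String> inputs) {
        List<List<String>> result = new ArrayList<>();
        int i = 0;
        while (i < inputs.size()) {
            List<String> batch = new ArrayList<>();
            long total = 0;
            do { // always take at least one input so progress is guaranteed
                String input = inputs.get(i++);
                batch.add(input);
                total += input.length();
            } while (i < inputs.size() && total < maxBatchSizeInBytes);
            result.add(batch);
        }
        return result;
    }
}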
Jim Ferenczi, 2025-03-14 09:51:03 +00:00 (committed by GitHub)
commit 361b51d436 (parent 35ecbf6e87)
7 changed files with 316 additions and 248 deletions

docs/changelog/124313.yaml

@@ -0,0 +1,5 @@
pr: 124313
summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
area: Search
type: enhancement
issues: []

ShardBulkInferenceActionFilterIT.java

@@ -20,6 +20,7 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
return Settings.builder()
.put(otherSettings)
.put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
.put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
.build();
}
@Override

InferencePlugin.java

@@ -142,6 +142,7 @@ import java.util.function.Predicate;
import java.util.function.Supplier;
import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public class InferencePlugin extends Plugin
settings.addAll(Truncator.getSettingsDefinitions());
settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
settings.add(SKIP_VALIDATE_AND_START);
settings.add(INDICES_INFERENCE_BATCH_SIZE);
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
return settings;

ShardBulkInferenceActionFilter.java

@@ -25,7 +25,11 @@ import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
@@ -43,6 +47,10 @@ import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.InferenceException;
@@ -63,6 +71,8 @@ import java.util.Map;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
/**
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -72,10 +82,23 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
* This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the
* results during indexing on the shard.
*
* TODO: batchSize should be configurable via a cluster setting
*/
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
protected static final int DEFAULT_BATCH_SIZE = 512;
private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
/**
* Defines the cumulative size limit of input data before triggering a batch inference call.
* This setting controls how much data can be accumulated before an inference request is sent in batch.
*/
public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
"indices.inference.batch_size",
DEFAULT_BATCH_SIZE,
ByteSizeValue.ONE,
ByteSizeValue.ofMb(100),
Setting.Property.NodeScope,
Setting.Property.OperatorDynamic
);
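// Illustrative usage, not part of this diff: because the setting is declared
// OperatorDynamic, an operator can adjust the cap at runtime through the
// cluster settings API, e.g.
//   PUT _cluster/settings
//   { "persistent": { "indices.inference.batch_size": "512kb" } }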
private static final Object EXPLICIT_NULL = new Object();
private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();
@@ -83,29 +106,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final int batchSize;
private volatile long batchSizeInBytes;
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState
) {
this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
}
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState,
int batchSize
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.batchSize = batchSize;
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
}
private void setBatchSize(ByteSizeValue newBatchSize) {
batchSizeInBytes = newBatchSize.getBytes();
}
@Override
@@ -148,14 +166,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
/**
* A field inference request on a single input.
* @param index The index of the request in the original bulk request.
* @param bulkItemIndex The index of the item in the original bulk request.
* @param field The target field.
* @param sourceField The source field.
* @param input The input to run inference on.
* @param inputOrder The original order of the input.
* @param offsetAdjustment The adjustment to apply to the chunk text offsets.
*/
private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {}
private record FieldInferenceRequest(
int bulkItemIndex,
String field,
String sourceField,
String input,
int inputOrder,
int offsetAdjustment
) {}
/**
* The field inference response.
@@ -218,29 +243,54 @@
@Override
public void run() {
Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
executeNext(0);
}
private void executeNext(int itemOffset) {
if (itemOffset >= bulkShardRequest.items().length) {
onCompletion.run();
return;
}
var items = bulkShardRequest.items();
Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
long totalInputLength = 0;
int itemIndex = itemOffset;
while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
var item = items[itemIndex];
totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
itemIndex += 1;
}
int nextItemOffset = itemIndex;
Runnable onInferenceCompletion = () -> {
try {
for (var inferenceResponse : inferenceResults.asList()) {
var request = bulkShardRequest.items()[inferenceResponse.id];
try {
applyInferenceResponses(request, inferenceResponse);
} catch (Exception exc) {
request.abort(bulkShardRequest.index(), exc);
for (int i = itemOffset; i < nextItemOffset; i++) {
var result = inferenceResults.get(i);
if (result == null) {
continue;
}
var item = items[i];
try {
applyInferenceResponses(item, result);
} catch (Exception exc) {
item.abort(bulkShardRequest.index(), exc);
}
// we don't need to keep the inference results around
inferenceResults.set(i, null);
}
} finally {
onCompletion.run();
executeNext(nextItemOffset);
}
};
try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
for (var entry : inferenceRequests.entrySet()) {
executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
for (var entry : fieldRequestsMap.entrySet()) {
executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
}
}
}
private void executeShardBulkInferenceAsync(
private void executeChunkedInferenceAsync(
final String inferenceId,
@Nullable InferenceProvider inferenceProvider,
final List<FieldInferenceRequest> requests,
@@ -262,11 +312,11 @@
unparsedModel.secrets()
)
);
executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
} else {
try (onFinish) {
for (FieldInferenceRequest request : requests) {
inferenceResults.get(request.index).failures.add(
inferenceResults.get(request.bulkItemIndex).failures.add(
new ResourceNotFoundException(
"Inference service [{}] not found for field [{}]",
unparsedModel.service(),
@@ -297,7 +347,7 @@
request.field
);
}
inferenceResults.get(request.index).failures.add(failure);
inferenceResults.get(request.bulkItemIndex).failures.add(failure);
}
}
}
@@ -305,18 +355,15 @@
modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
return;
}
int currentBatchSize = Math.min(requests.size(), batchSize);
final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
@Override
public void onResponse(List<ChunkedInference> results) {
try {
try (onFinish) {
var requestsIterator = requests.iterator();
for (ChunkedInference result : results) {
var request = requestsIterator.next();
var acc = inferenceResults.get(request.index);
var acc = inferenceResults.get(request.bulkItemIndex);
if (result instanceof ChunkedInferenceError error) {
acc.addFailure(
new InferenceException(
@@ -331,7 +378,7 @@
new FieldInferenceResponse(
request.field(),
request.sourceField(),
request.input(),
useLegacyFormat ? request.input() : null,
request.inputOrder(),
request.offsetAdjustment(),
inferenceProvider.model,
@@ -340,17 +387,15 @@
);
}
}
} finally {
onFinish();
}
}
@Override
public void onFailure(Exception exc) {
try {
try (onFinish) {
for (FieldInferenceRequest request : requests) {
addInferenceResponseFailure(
request.index,
request.bulkItemIndex,
new InferenceException(
"Exception when running inference id [{}] on field [{}]",
exc,
@@ -359,16 +404,6 @@
)
);
}
} finally {
onFinish();
}
}
private void onFinish() {
if (nextBatch.isEmpty()) {
onFinish.close();
} else {
executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish);
}
}
};
@@ -376,6 +411,132 @@
.chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
}
/**
* Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
* for the specified {@code item}.
*
* @param item The bulk request item to process.
* @param itemIndex The position of the item within the original bulk request.
* @param requestsMap A map storing inference requests, where each key is an inference ID,
* and the value is a list of associated {@link FieldInferenceRequest} objects.
* @return The total content length of all newly added requests, or {@code 0} if no requests were added.
*/
private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
boolean isUpdateRequest = false;
final IndexRequest indexRequest;
if (item.request() instanceof IndexRequest ir) {
indexRequest = ir;
} else if (item.request() instanceof UpdateRequest updateRequest) {
isUpdateRequest = true;
if (updateRequest.script() != null) {
addInferenceResponseFailure(
itemIndex,
new ElasticsearchStatusException(
"Cannot apply update with a script on indices that contain [{}] field(s)",
RestStatus.BAD_REQUEST,
SemanticTextFieldMapper.CONTENT_TYPE
)
);
return 0;
}
indexRequest = updateRequest.doc();
} else {
// ignore delete request
return 0;
}
final Map<String, Object> docMap = indexRequest.sourceAsMap();
long inputLength = 0;
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
String inferenceId = entry.getInferenceId();
if (useLegacyFormat) {
var originalFieldValue = XContentMapValues.extractValue(field, docMap);
if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
// Inference has already been computed, or there is no inference required.
continue;
}
} else {
var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + "." + field,
docMap,
EXPLICIT_NULL
);
if (inferenceMetadataFieldsValue != null) {
// Inference has already been computed
continue;
}
}
int order = 0;
for (var sourceField : entry.getSourceFields()) {
var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
/**
* It's an update request, and the source field is explicitly set to null,
* so we need to propagate this information to the inference fields metadata
* to overwrite any inference previously computed on the field.
* This ensures that the field is treated as intentionally cleared,
* preventing any unintended carryover of prior inference results.
*/
var slot = ensureResponseAccumulatorSlot(itemIndex);
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
);
continue;
}
if (valueObj == null || valueObj == EXPLICIT_NULL) {
if (isUpdateRequest && useLegacyFormat) {
addInferenceResponseFailure(
itemIndex,
new ElasticsearchStatusException(
"Field [{}] must be specified on an update request to calculate inference for field [{}]",
RestStatus.BAD_REQUEST,
sourceField,
field
)
);
break;
}
continue;
}
var slot = ensureResponseAccumulatorSlot(itemIndex);
final List<String> values;
try {
values = SemanticTextUtils.nodeStringValues(field, valueObj);
} catch (Exception exc) {
addInferenceResponseFailure(itemIndex, exc);
break;
}
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
break;
}
List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
inputLength += v.length();
if (v.isBlank()) {
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
);
} else {
requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
}
// When using the inference metadata fields format, all the input values are concatenated so that the
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
// to apply to account for this.
offsetAdjustment += v.length() + 1; // Add one for separator char length
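// Illustrative example, not part of this diff: for values ["a b", "c d"] the
// concatenated view is "a b c d", so chunk offsets for the second value are
// shifted by 4 ("a b".length() plus one separator character).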
}
}
}
return inputLength;
}
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
if (acc == null) {
@@ -404,7 +565,6 @@
}
final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
var newDocMap = indexRequest.sourceAsMap();
Map<String, Object> inferenceFieldsMap = new HashMap<>();
for (var entry : response.responses.entrySet()) {
var fieldName = entry.getKey();
@@ -426,28 +586,22 @@
}
var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
lst.addAll(
SemanticTextField.toSemanticTextFieldChunks(
resp.input,
resp.offsetAdjustment,
resp.chunkedResults,
indexRequest.getContentType(),
useLegacyFormat
)
);
var chunks = useLegacyFormat
? toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType())
: toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType());
lst.addAll(chunks);
}
List<String> inputs = responses.stream()
.filter(r -> r.sourceField().equals(fieldName))
.map(r -> r.input)
.collect(Collectors.toList());
List<String> inputs = useLegacyFormat
? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList())
: null;
// The model can be null if we are only processing update requests that clear inference results. This is ok because we will
// merge in the field's existing model settings on the data node.
var result = new SemanticTextField(
useLegacyFormat,
fieldName,
useLegacyFormat ? inputs : null,
inputs,
new SemanticTextField.InferenceResult(
inferenceFieldMetadata.getInferenceId(),
model != null ? new MinimalServiceSettings(model) : null,
@@ -455,149 +609,52 @@
),
indexRequest.getContentType()
);
inferenceFieldsMap.put(fieldName, result);
}
if (useLegacyFormat) {
SemanticTextUtils.insertValue(fieldName, newDocMap, result);
} else {
inferenceFieldsMap.put(fieldName, result);
if (useLegacyFormat) {
var newDocMap = indexRequest.sourceAsMap();
for (var entry : inferenceFieldsMap.entrySet()) {
SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue());
}
indexRequest.source(newDocMap, indexRequest.getContentType());
} else {
try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap);
indexRequest.source(builder);
}
}
if (useLegacyFormat == false) {
newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);
}
}
/**
* Appends the original source and the new inference metadata field directly to the provided
* {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
*/
private static void appendSourceAndInferenceMetadata(
XContentBuilder builder,
BytesReference source,
XContentType xContentType,
Map<String, Object> inferenceFieldsMap
) throws IOException {
builder.startObject();
// append the original source
try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) {
// skip start object
parser.nextToken();
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
builder.copyCurrentStructure(parser);
}
indexRequest.source(newDocMap, indexRequest.getContentType());
}
/**
* Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index.
* If results are already populated for fields in the original index request, the inference request for this specific
* field is skipped, and the existing results remain unchanged.
* Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing,
* where an error will be thrown if they mismatch or if the content is malformed.
* <p>
* TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ?
*/
private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) {
var item = bulkShardRequest.items()[itemIndex];
if (item.getPrimaryResponse() != null) {
// item was already aborted/processed by a filter in the chain upstream (e.g. security)
continue;
}
boolean isUpdateRequest = false;
final IndexRequest indexRequest;
if (item.request() instanceof IndexRequest ir) {
indexRequest = ir;
} else if (item.request() instanceof UpdateRequest updateRequest) {
isUpdateRequest = true;
if (updateRequest.script() != null) {
addInferenceResponseFailure(
itemIndex,
new ElasticsearchStatusException(
"Cannot apply update with a script on indices that contain [{}] field(s)",
RestStatus.BAD_REQUEST,
SemanticTextFieldMapper.CONTENT_TYPE
)
);
continue;
}
indexRequest = updateRequest.doc();
} else {
// ignore delete request
continue;
}
final Map<String, Object> docMap = indexRequest.sourceAsMap();
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
String inferenceId = entry.getInferenceId();
if (useLegacyFormat) {
var originalFieldValue = XContentMapValues.extractValue(field, docMap);
if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
// Inference has already been computed, or there is no inference required.
continue;
}
} else {
var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + "." + field,
docMap,
EXPLICIT_NULL
);
if (inferenceMetadataFieldsValue != null) {
// Inference has already been computed
continue;
}
}
int order = 0;
for (var sourceField : entry.getSourceFields()) {
var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
/**
* It's an update request, and the source field is explicitly set to null,
* so we need to propagate this information to the inference fields metadata
* to overwrite any inference previously computed on the field.
* This ensures that the field is treated as intentionally cleared,
* preventing any unintended carryover of prior inference results.
*/
var slot = ensureResponseAccumulatorSlot(itemIndex);
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
);
continue;
}
if (valueObj == null || valueObj == EXPLICIT_NULL) {
if (isUpdateRequest && useLegacyFormat) {
addInferenceResponseFailure(
itemIndex,
new ElasticsearchStatusException(
"Field [{}] must be specified on an update request to calculate inference for field [{}]",
RestStatus.BAD_REQUEST,
sourceField,
field
)
);
break;
}
continue;
}
var slot = ensureResponseAccumulatorSlot(itemIndex);
final List<String> values;
try {
values = SemanticTextUtils.nodeStringValues(field, valueObj);
} catch (Exception exc) {
addInferenceResponseFailure(itemIndex, exc);
break;
}
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
break;
}
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
if (v.isBlank()) {
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
);
} else {
fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
}
// When using the inference metadata fields format, all the input values are concatenated so that the
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
// to apply to account for this.
offsetAdjustment += v.length() + 1; // Add one for separator char length
}
}
}
}
return fieldRequestsMap;
// add the inference metadata field
builder.field(InferenceMetadataFieldsMapper.NAME);
try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
builder.copyCurrentStructure(parser);
}
builder.endObject();
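// Illustrative note, not part of this diff: the builder above yields a source of
// the shape { ...original fields..., <InferenceMetadataFieldsMapper.NAME>: { field -> result } },
// i.e. the inference metadata is appended as a single extra top-level field
// without materializing the original source as a Map.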
}
static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {

SemanticTextField.java

@@ -267,37 +267,38 @@ public record SemanticTextField(
/**
* Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
*/
public static List<Chunk> toSemanticTextFieldChunks(
String input,
int offsetAdjustment,
ChunkedInference results,
XContentType contentType,
boolean useLegacyFormat
) throws IOException {
public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
throws IOException {
List<Chunk> chunks = new ArrayList<>();
Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
while (it.hasNext()) {
chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
}
return chunks;
}
public static Chunk toSemanticTextFieldChunk(
String input,
int offsetAdjustment,
ChunkedInference.Chunk chunk,
boolean useLegacyFormat
) {
/**
* Converts the provided {@link ChunkedInference.Chunk} into a {@link Chunk}, applying the given offset adjustment.
*/
public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
String text = null;
int startOffset = -1;
int endOffset = -1;
if (useLegacyFormat) {
text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
} else {
startOffset = chunk.textOffset().start() + offsetAdjustment;
endOffset = chunk.textOffset().end() + offsetAdjustment;
}
int startOffset = chunk.textOffset().start() + offsetAdjustment;
int endOffset = chunk.textOffset().end() + offsetAdjustment;
return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
}
public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
throws IOException {
List<Chunk> chunks = new ArrayList<>();
Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
while (it.hasNext()) {
chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
}
return chunks;
}
public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
return new Chunk(text, -1, -1, chunk.bytesReference());
}
}

ShardBulkInferenceActionFilterTests.java

@@ -28,7 +28,9 @@ import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.IndexVersion;
@@ -66,12 +68,13 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -118,7 +121,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testFilterNoop() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -144,7 +147,7 @@
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testLicenseInvalidForInference() throws InterruptedException {
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -185,7 +188,6 @@
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -232,7 +234,6 @@
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -303,7 +304,6 @@
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -374,7 +374,6 @@
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -447,13 +446,7 @@
modifiedRequests[id] = res[1];
}
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
inferenceModelMap,
randomIntBetween(10, 30),
useLegacyFormat,
true
);
ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -487,7 +480,6 @@
private static ShardBulkInferenceActionFilter createFilter(
ThreadPool threadPool,
Map<String, StaticModel> modelMap,
int batchSize,
boolean useLegacyFormat,
boolean isLicenseValidForInference
) {
@@ -554,18 +546,17 @@
createClusterService(useLegacyFormat),
inferenceServiceRegistry,
modelRegistry,
licenseState,
batchSize
licenseState
);
}
private static ClusterService createClusterService(boolean useLegacyFormat) {
IndexMetadata indexMetadata = mock(IndexMetadata.class);
var settings = Settings.builder()
var indexSettings = Settings.builder()
.put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
.build();
when(indexMetadata.getSettings()).thenReturn(settings);
when(indexMetadata.getSettings()).thenReturn(indexSettings);
ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build());
when(project.index(anyString())).thenReturn(indexMetadata);
@@ -576,7 +567,10 @@
ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenReturn(clusterState);
long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
when(clusterService.getSettings()).thenReturn(settings);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
return clusterService;
}
@@ -587,7 +581,8 @@
) throws IOException {
Map<String, Object> docMap = new LinkedHashMap<>();
Map<String, Object> expectedDocMap = new LinkedHashMap<>();
XContentType requestContentType = randomFrom(XContentType.values());
// force JSON to avoid double/float conversions
XContentType requestContentType = XContentType.JSON;
Map<String, Object> inferenceMetadataFields = new HashMap<>();
for (var entry : fieldInferenceMap.values()) {

SemanticTextFieldTests.java

@@ -41,6 +41,7 @@ import java.util.function.Predicate;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -274,7 +275,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTextField> {
while (inputsIt.hasNext() && chunkIt.hasNext()) {
String input = inputsIt.next();
var chunk = chunkIt.next();
chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));
// When using the inference metadata fields format, all the input values are concatenated so that the
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment