Mirror of https://github.com/elastic/elasticsearch.git, synced 2025-06-28 17:34:17 -04:00
Optimize memory usage in ShardBulkInferenceActionFilter (#124313)
This refactor improves memory efficiency by processing inference requests in batches, capped by a maximum input length. Changes include:

- A new dynamic operator setting to control the maximum batch size in bytes.
- Dropping input data from inference responses when the legacy semantic text format isn't used, saving memory.
- Clearing inference results after each bulk item to free up memory sooner.

This is a step toward enabling circuit breakers to better handle memory usage when dealing with large inputs.
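As a rough illustration of the batching strategy described above, the sketch below walks the bulk items in order, grows a window until the accumulated input size reaches a byte cap, runs inference for that window, applies the results, and clears them before moving on. This is a minimal, hypothetical example: the class and method names are invented for illustration, and the real logic lives in `ShardBulkInferenceActionFilter` in the diff below.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: batches inputs by a cumulative byte cap and frees results per window. */
final class ByteCappedBatcher {
    private final long batchSizeInBytes;

    ByteCappedBatcher(long batchSizeInBytes) {
        this.batchSizeInBytes = batchSizeInBytes;
    }

    /** Processes items window by window, never holding more than one window of results. */
    void process(List<String> bulkItemInputs) {
        int offset = 0;
        while (offset < bulkItemInputs.size()) {
            long total = 0;
            List<String> window = new ArrayList<>();
            int next = offset;
            // Grow the window until the byte cap is reached (always taking at least one item).
            while (next < bulkItemInputs.size() && (window.isEmpty() || total < batchSizeInBytes)) {
                String input = bulkItemInputs.get(next);
                total += input.length();
                window.add(input);
                next++;
            }
            List<float[]> results = runInference(window);  // stand-in for the chunked inference call
            applyToBulkItems(offset, results);             // stand-in for merging results into bulk items
            results.clear();                               // drop the results before starting the next window
            offset = next;
        }
    }

    private List<float[]> runInference(List<String> inputs) {
        return new ArrayList<>(inputs.size());  // placeholder
    }

    private void applyToBulkItems(int offset, List<float[]> results) {
        // placeholder: the real filter writes the results into each bulk item's source
    }
}
```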
Parent: 35ecbf6e87
Commit: 361b51d436
7 changed files with 316 additions and 248 deletions
docs/changelog/124313.yaml (new file, +5 lines)
@@ -0,0 +1,5 @@
+pr: 124313
+summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+area: Search
+type: enhancement
+issues: []
@@ -20,6 +20,7 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Set;

+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {

     @Override
     protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        return Settings.builder()
+            .put(otherSettings)
+            .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+            .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+            .build();
     }

     @Override
@@ -142,6 +142,7 @@ import java.util.function.Predicate;
 import java.util.function.Supplier;

 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

 public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public class InferencePlugin extends Plugin
         settings.addAll(Truncator.getSettingsDefinitions());
         settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
         settings.add(SKIP_VALIDATE_AND_START);
+        settings.add(INDICES_INFERENCE_BATCH_SIZE);
         settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());

         return settings;
@@ -25,7 +25,11 @@ import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
@@ -43,6 +47,10 @@ import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.inference.InferenceException;
@@ -63,6 +71,8 @@ import java.util.Map;
 import java.util.stream.Collectors;

 import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;

 /**
  * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -72,10 +82,23 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
  * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the
  * results during indexing on the shard.
  *
- * TODO: batchSize should be configurable via a cluster setting
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
-    protected static final int DEFAULT_BATCH_SIZE = 512;
+    private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
+
+    /**
+     * Defines the cumulative size limit of input data before triggering a batch inference call.
+     * This setting controls how much data can be accumulated before an inference request is sent in batch.
+     */
+    public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
+        "indices.inference.batch_size",
+        DEFAULT_BATCH_SIZE,
+        ByteSizeValue.ONE,
+        ByteSizeValue.ofMb(100),
+        Setting.Property.NodeScope,
+        Setting.Property.OperatorDynamic
+    );

     private static final Object EXPLICIT_NULL = new Object();
     private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();
@@ -83,29 +106,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
-    private final int batchSize;
+    private volatile long batchSizeInBytes;

     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
         XPackLicenseState licenseState
     ) {
-        this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
-    }
-
-    public ShardBulkInferenceActionFilter(
-        ClusterService clusterService,
-        InferenceServiceRegistry inferenceServiceRegistry,
-        ModelRegistry modelRegistry,
-        XPackLicenseState licenseState,
-        int batchSize
-    ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
-        this.batchSize = batchSize;
+        this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
+
+    private void setBatchSize(ByteSizeValue newBatchSize) {
+        batchSizeInBytes = newBatchSize.getBytes();
+    }

     @Override
@@ -148,14 +166,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {

     /**
      * A field inference request on a single input.
-     * @param index The index of the request in the original bulk request.
+     * @param bulkItemIndex The index of the item in the original bulk request.
      * @param field The target field.
      * @param sourceField The source field.
      * @param input The input to run inference on.
      * @param inputOrder The original order of the input.
      * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
      */
-    private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {}
+    private record FieldInferenceRequest(
+        int bulkItemIndex,
+        String field,
+        String sourceField,
+        String input,
+        int inputOrder,
+        int offsetAdjustment
+    ) {}

     /**
      * The field inference response.
@@ -218,29 +243,54 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {

         @Override
         public void run() {
-            Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
+            executeNext(0);
+        }
+
+        private void executeNext(int itemOffset) {
+            if (itemOffset >= bulkShardRequest.items().length) {
+                onCompletion.run();
+                return;
+            }
+
+            var items = bulkShardRequest.items();
+            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
+            long totalInputLength = 0;
+            int itemIndex = itemOffset;
+            while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
+                var item = items[itemIndex];
+                totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
+                itemIndex += 1;
+            }
+            int nextItemOffset = itemIndex;
             Runnable onInferenceCompletion = () -> {
                 try {
-                    for (var inferenceResponse : inferenceResults.asList()) {
-                        var request = bulkShardRequest.items()[inferenceResponse.id];
-                        try {
-                            applyInferenceResponses(request, inferenceResponse);
-                        } catch (Exception exc) {
-                            request.abort(bulkShardRequest.index(), exc);
+                    for (int i = itemOffset; i < nextItemOffset; i++) {
+                        var result = inferenceResults.get(i);
+                        if (result == null) {
+                            continue;
+                        }
+                        var item = items[i];
+                        try {
+                            applyInferenceResponses(item, result);
+                        } catch (Exception exc) {
+                            item.abort(bulkShardRequest.index(), exc);
                         }
+                        // we don't need to keep the inference results around
+                        inferenceResults.set(i, null);
                     }
                 } finally {
-                    onCompletion.run();
+                    executeNext(nextItemOffset);
                 }
             };

             try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
-                for (var entry : inferenceRequests.entrySet()) {
-                    executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
+                for (var entry : fieldRequestsMap.entrySet()) {
+                    executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
                 }
             }
         }

-        private void executeShardBulkInferenceAsync(
+        private void executeChunkedInferenceAsync(
             final String inferenceId,
             @Nullable InferenceProvider inferenceProvider,
             final List<FieldInferenceRequest> requests,
@@ -262,11 +312,11 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                             unparsedModel.secrets()
                         )
                     );
-                    executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
+                    executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
                 } else {
                     try (onFinish) {
                         for (FieldInferenceRequest request : requests) {
-                            inferenceResults.get(request.index).failures.add(
+                            inferenceResults.get(request.bulkItemIndex).failures.add(
                                 new ResourceNotFoundException(
                                     "Inference service [{}] not found for field [{}]",
                                     unparsedModel.service(),
@@ -297,7 +347,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 request.field
                             );
                         }
-                        inferenceResults.get(request.index).failures.add(failure);
+                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
                     }
                 }
             }
@@ -305,18 +355,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
                 return;
             }
-            int currentBatchSize = Math.min(requests.size(), batchSize);
-            final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
-            final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
-            final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+            final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
             ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
                 @Override
                 public void onResponse(List<ChunkedInference> results) {
-                    try {
+                    try (onFinish) {
                         var requestsIterator = requests.iterator();
                         for (ChunkedInference result : results) {
                             var request = requestsIterator.next();
-                            var acc = inferenceResults.get(request.index);
+                            var acc = inferenceResults.get(request.bulkItemIndex);
                             if (result instanceof ChunkedInferenceError error) {
                                 acc.addFailure(
                                     new InferenceException(
@@ -331,7 +378,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                     new FieldInferenceResponse(
                                         request.field(),
                                         request.sourceField(),
-                                        request.input(),
+                                        useLegacyFormat ? request.input() : null,
                                         request.inputOrder(),
                                         request.offsetAdjustment(),
                                         inferenceProvider.model,
@@ -340,17 +387,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                     );
                                 }
                             }
-                        } finally {
-                            onFinish();
                         }
                     }

                     @Override
                     public void onFailure(Exception exc) {
-                        try {
+                        try (onFinish) {
                             for (FieldInferenceRequest request : requests) {
                                 addInferenceResponseFailure(
-                                    request.index,
+                                    request.bulkItemIndex,
                                     new InferenceException(
                                         "Exception when running inference id [{}] on field [{}]",
                                         exc,
@@ -359,16 +404,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                     )
                                 );
                             }
-                        } finally {
-                            onFinish();
                         }
                     }
-
-                    private void onFinish() {
-                        if (nextBatch.isEmpty()) {
-                            onFinish.close();
-                        } else {
-                            executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish);
-                        }
-                    }
                 };
@@ -376,115 +411,17 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
         }

-        private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
-            FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
-            if (acc == null) {
-                acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
-                inferenceResults.set(id, acc);
-            }
-            return acc;
-        }
-
-        private void addInferenceResponseFailure(int id, Exception failure) {
-            var acc = ensureResponseAccumulatorSlot(id);
-            acc.addFailure(failure);
-        }
-
         /**
-         * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
-         * If the response contains failures, the bulk item request is marked as failed for the downstream action.
-         * Otherwise, the source of the request is augmented with the field inference results.
+         * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
+         * for the specified {@code item}.
+         *
+         * @param item The bulk request item to process.
+         * @param itemIndex The position of the item within the original bulk request.
+         * @param requestsMap A map storing inference requests, where each key is an inference ID,
+         *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
+         * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
         */
-        private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException {
-            if (response.failures().isEmpty() == false) {
-                for (var failure : response.failures()) {
-                    item.abort(item.index(), failure);
-                }
-                return;
-            }
-
-            final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
-            var newDocMap = indexRequest.sourceAsMap();
-            Map<String, Object> inferenceFieldsMap = new HashMap<>();
-            for (var entry : response.responses.entrySet()) {
-                var fieldName = entry.getKey();
-                var responses = entry.getValue();
-                Model model = null;
-
-                InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
-                if (inferenceFieldMetadata == null) {
-                    throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
-                }
-
-                // ensure that the order in the original field is consistent in case of multiple inputs
-                Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
-                Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
-                for (var resp : responses) {
-                    // Get the first non-null model from the response list
-                    if (model == null) {
-                        model = resp.model;
-                    }
-
-                    var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
-                    lst.addAll(
-                        SemanticTextField.toSemanticTextFieldChunks(
-                            resp.input,
-                            resp.offsetAdjustment,
-                            resp.chunkedResults,
-                            indexRequest.getContentType(),
-                            useLegacyFormat
-                        )
-                    );
-                }
-
-                List<String> inputs = responses.stream()
-                    .filter(r -> r.sourceField().equals(fieldName))
-                    .map(r -> r.input)
-                    .collect(Collectors.toList());
-
-                // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
-                // merge in the field's existing model settings on the data node.
-                var result = new SemanticTextField(
-                    useLegacyFormat,
-                    fieldName,
-                    useLegacyFormat ? inputs : null,
-                    new SemanticTextField.InferenceResult(
-                        inferenceFieldMetadata.getInferenceId(),
-                        model != null ? new MinimalServiceSettings(model) : null,
-                        chunkMap
-                    ),
-                    indexRequest.getContentType()
-                );
-
-                if (useLegacyFormat) {
-                    SemanticTextUtils.insertValue(fieldName, newDocMap, result);
-                } else {
-                    inferenceFieldsMap.put(fieldName, result);
-                }
-            }
-            if (useLegacyFormat == false) {
-                newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);
-            }
-            indexRequest.source(newDocMap, indexRequest.getContentType());
-        }
-
-        /**
-         * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index.
-         * If results are already populated for fields in the original index request, the inference request for this specific
-         * field is skipped, and the existing results remain unchanged.
-         * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing,
-         * where an error will be thrown if they mismatch or if the content is malformed.
-         * <p>
-         * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ?
-         */
-        private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
-            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
-            for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) {
-                var item = bulkShardRequest.items()[itemIndex];
-                if (item.getPrimaryResponse() != null) {
-                    // item was already aborted/processed by a filter in the chain upstream (e.g. security)
-                    continue;
-                }
+        private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
             final IndexRequest indexRequest;
             if (item.request() instanceof IndexRequest ir) {
@@ -500,15 +437,16 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                             SemanticTextFieldMapper.CONTENT_TYPE
                         )
                     );
-                    continue;
+                    return 0;
                 }
                 indexRequest = updateRequest.doc();
             } else {
                 // ignore delete request
-                continue;
+                return 0;
             }

             final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
                 String inferenceId = entry.getInferenceId();
@@ -577,15 +515,16 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     break;
                 }

-                List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+                List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
+                    inputLength += v.length();
                     if (v.isBlank()) {
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
                         );
                     } else {
-                        fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
+                        requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
                     }

                     // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -595,9 +534,127 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 }
             }
         }
+            return inputLength;
         }
-        return fieldRequestsMap;

+        private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
+            FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
+            if (acc == null) {
+                acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
+                inferenceResults.set(id, acc);
+            }
+            return acc;
+        }
+
+        private void addInferenceResponseFailure(int id, Exception failure) {
+            var acc = ensureResponseAccumulatorSlot(id);
+            acc.addFailure(failure);
+        }
+
+        /**
+         * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
+         * If the response contains failures, the bulk item request is marked as failed for the downstream action.
+         * Otherwise, the source of the request is augmented with the field inference results.
+         */
+        private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException {
+            if (response.failures().isEmpty() == false) {
+                for (var failure : response.failures()) {
+                    item.abort(item.index(), failure);
+                }
+                return;
+            }
+
+            final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
+            Map<String, Object> inferenceFieldsMap = new HashMap<>();
+            for (var entry : response.responses.entrySet()) {
+                var fieldName = entry.getKey();
+                var responses = entry.getValue();
+                Model model = null;
+
+                InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
+                if (inferenceFieldMetadata == null) {
+                    throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
+                }
+
+                // ensure that the order in the original field is consistent in case of multiple inputs
+                Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
+                Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
+                for (var resp : responses) {
+                    // Get the first non-null model from the response list
+                    if (model == null) {
+                        model = resp.model;
+                    }
+
+                    var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
+                    var chunks = useLegacyFormat
+                        ? toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType())
+                        : toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType());
+                    lst.addAll(chunks);
+                }
+
+                List<String> inputs = useLegacyFormat
+                    ? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList())
+                    : null;
+
+                // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
+                // merge in the field's existing model settings on the data node.
+                var result = new SemanticTextField(
+                    useLegacyFormat,
+                    fieldName,
+                    inputs,
+                    new SemanticTextField.InferenceResult(
+                        inferenceFieldMetadata.getInferenceId(),
+                        model != null ? new MinimalServiceSettings(model) : null,
+                        chunkMap
+                    ),
+                    indexRequest.getContentType()
+                );
+                inferenceFieldsMap.put(fieldName, result);
+            }
+
+            if (useLegacyFormat) {
+                var newDocMap = indexRequest.sourceAsMap();
+                for (var entry : inferenceFieldsMap.entrySet()) {
+                    SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue());
+                }
+                indexRequest.source(newDocMap, indexRequest.getContentType());
+            } else {
+                try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
+                    appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap);
+                    indexRequest.source(builder);
+                }
+            }
+        }
     }

+    /**
+     * Appends the original source and the new inference metadata field directly to the provided
+     * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
+     */
+    private static void appendSourceAndInferenceMetadata(
+        XContentBuilder builder,
+        BytesReference source,
+        XContentType xContentType,
+        Map<String, Object> inferenceFieldsMap
+    ) throws IOException {
+        builder.startObject();
+
+        // append the original source
+        try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) {
+            // skip start object
+            parser.nextToken();
+            while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                builder.copyCurrentStructure(parser);
+            }
+        }
+
+        // add the inference metadata field
+        builder.field(InferenceMetadataFieldsMapper.NAME);
+        try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
+            builder.copyCurrentStructure(parser);
+        }
+
+        builder.endObject();
+    }
+
     static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
@@ -267,37 +267,38 @@ public record SemanticTextField(
     /**
      * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
      */
-    public static List<Chunk> toSemanticTextFieldChunks(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference results,
-        XContentType contentType,
-        boolean useLegacyFormat
-    ) throws IOException {
+    public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
+        throws IOException {
         List<Chunk> chunks = new ArrayList<>();
         Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
         while (it.hasNext()) {
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
+            chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
         }
         return chunks;
     }

-    public static Chunk toSemanticTextFieldChunk(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference.Chunk chunk,
-        boolean useLegacyFormat
-    ) {
+    /**
+     * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
+     */
+    public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
         String text = null;
-        int startOffset = -1;
-        int endOffset = -1;
-        if (useLegacyFormat) {
-            text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
-        } else {
-            startOffset = chunk.textOffset().start() + offsetAdjustment;
-            endOffset = chunk.textOffset().end() + offsetAdjustment;
-        }
-
+        int startOffset = chunk.textOffset().start() + offsetAdjustment;
+        int endOffset = chunk.textOffset().end() + offsetAdjustment;
         return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
     }

+    public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
+        throws IOException {
+        List<Chunk> chunks = new ArrayList<>();
+        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
+        while (it.hasNext()) {
+            chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
+        }
+        return chunks;
+    }
+
+    public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+        var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+        return new Chunk(text, -1, -1, chunk.bytesReference());
+    }
 }
@@ -28,7 +28,9 @@ import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
@@ -66,12 +68,13 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;

 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
-import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -118,7 +121,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -144,7 +147,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -185,7 +188,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -232,7 +234,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -303,7 +304,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -374,7 +374,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -447,13 +446,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             modifiedRequests[id] = res[1];
         }

-        ShardBulkInferenceActionFilter filter = createFilter(
-            threadPool,
-            inferenceModelMap,
-            randomIntBetween(10, 30),
-            useLegacyFormat,
-            true
-        );
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -487,7 +480,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     private static ShardBulkInferenceActionFilter createFilter(
         ThreadPool threadPool,
         Map<String, StaticModel> modelMap,
-        int batchSize,
         boolean useLegacyFormat,
         boolean isLicenseValidForInference
     ) {
@@ -554,18 +546,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             createClusterService(useLegacyFormat),
             inferenceServiceRegistry,
             modelRegistry,
-            licenseState,
-            batchSize
+            licenseState
         );
     }

     private static ClusterService createClusterService(boolean useLegacyFormat) {
         IndexMetadata indexMetadata = mock(IndexMetadata.class);
-        var settings = Settings.builder()
+        var indexSettings = Settings.builder()
             .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
             .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
             .build();
-        when(indexMetadata.getSettings()).thenReturn(settings);
+        when(indexMetadata.getSettings()).thenReturn(indexSettings);

         ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build());
         when(project.index(anyString())).thenReturn(indexMetadata);
@@ -576,7 +567,10 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
         ClusterService clusterService = mock(ClusterService.class);
         when(clusterService.state()).thenReturn(clusterState);

+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
         when(clusterService.getSettings()).thenReturn(settings);
+        when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
         return clusterService;
     }
@@ -587,7 +581,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     ) throws IOException {
         Map<String, Object> docMap = new LinkedHashMap<>();
         Map<String, Object> expectedDocMap = new LinkedHashMap<>();
-        XContentType requestContentType = randomFrom(XContentType.values());
+        // force JSON to avoid double/float conversions
+        XContentType requestContentType = XContentType.JSON;

         Map<String, Object> inferenceMetadataFields = new HashMap<>();
         for (var entry : fieldInferenceMap.values()) {
@@ -41,6 +41,7 @@ import java.util.function.Predicate;

 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -274,7 +275,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTextField> {
         while (inputsIt.hasNext() && chunkIt.hasNext()) {
             String input = inputsIt.next();
             var chunk = chunkIt.next();
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
+            chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));

             // When using the inference metadata fields format, all the input values are concatenated so that the
             // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment