diff --git a/docs/changelog/124313.yaml b/docs/changelog/124313.yaml
new file mode 100644
index 000000000000..fc4d4d9d815e
--- /dev/null
+++ b/docs/changelog/124313.yaml
@@ -0,0 +1,5 @@
+pr: 124313
+summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+area: Search
+type: enhancement
+issues: []
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
index c85911c00947..5b3a95ca4014 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
@@ -20,6 +20,7 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
 
     @Override
     protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        return Settings.builder()
+            .put(otherSettings)
+            .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+            .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+            .build();
     }
 
     @Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 7263f204808d..5f941d486d57 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -142,6 +142,7 @@ import java.util.function.Predicate;
 import java.util.function.Supplier;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
 
 public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public class InferencePlugin extends Plugin
         settings.addAll(Truncator.getSettingsDefinitions());
         settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
         settings.add(SKIP_VALIDATE_AND_START);
+
settings.add(INDICES_INFERENCE_BATCH_SIZE); settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions()); return settings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 16336d941fd9..59555bfed4a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -25,7 +25,11 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; @@ -43,6 +47,10 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferenceException; @@ -63,6 +71,8 @@ import java.util.Map; import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified @@ -72,10 +82,23 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FE * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the * results during indexing on the shard. * - * TODO: batchSize should be configurable via a cluster setting */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { - protected static final int DEFAULT_BATCH_SIZE = 512; + private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1); + + /** + * Defines the cumulative size limit of input data before triggering a batch inference call. + * This setting controls how much data can be accumulated before an inference request is sent in batch. 
+ */ + public static Setting INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting( + "indices.inference.batch_size", + DEFAULT_BATCH_SIZE, + ByteSizeValue.ONE, + ByteSizeValue.ofMb(100), + Setting.Property.NodeScope, + Setting.Property.OperatorDynamic + ); + private static final Object EXPLICIT_NULL = new Object(); private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference(); @@ -83,29 +106,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; private final XPackLicenseState licenseState; - private final int batchSize; + private volatile long batchSizeInBytes; public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, XPackLicenseState licenseState - ) { - this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE); - } - - public ShardBulkInferenceActionFilter( - ClusterService clusterService, - InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry, - XPackLicenseState licenseState, - int batchSize ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; this.licenseState = licenseState; - this.batchSize = batchSize; + this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes(); + clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize); + } + + private void setBatchSize(ByteSizeValue newBatchSize) { + batchSizeInBytes = newBatchSize.getBytes(); } @Override @@ -148,14 +166,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { /** * A field inference request on a single input. - * @param index The index of the request in the original bulk request. + * @param bulkItemIndex The index of the item in the original bulk request. * @param field The target field. * @param sourceField The source field. * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. */ - private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {} + private record FieldInferenceRequest( + int bulkItemIndex, + String field, + String sourceField, + String input, + int inputOrder, + int offsetAdjustment + ) {} /** * The field inference response. 
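Note on the settings pattern introduced above: the batch size is now a bounded, dynamically updatable byte-size node setting. The filter reads the initial value from the node `Settings` and registers an update consumer so later changes take effect without a restart; the setting also has to be exposed from the plugin's `getSettings()`, which is what the `InferencePlugin` change does. A condensed sketch of that pattern, with a hypothetical holder class standing in for the filter:

```java
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.unit.ByteSizeValue;

// Hypothetical holder class illustrating the define/consume pattern used by the filter.
class BatchSizeHolder {
    // Bounded byte-size setting: defaults to 1mb, accepts values between 1b and 100mb,
    // node-scoped and dynamically updatable.
    static final Setting<ByteSizeValue> BATCH_SIZE = Setting.byteSizeSetting(
        "indices.inference.batch_size",
        ByteSizeValue.ofMb(1),
        ByteSizeValue.ONE,
        ByteSizeValue.ofMb(100),
        Setting.Property.NodeScope,
        Setting.Property.OperatorDynamic
    );

    private volatile long batchSizeInBytes;

    BatchSizeHolder(ClusterService clusterService) {
        // Initial value comes from the node settings; later updates arrive via the consumer.
        this.batchSizeInBytes = BATCH_SIZE.get(clusterService.getSettings()).getBytes();
        clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_SIZE, v -> batchSizeInBytes = v.getBytes());
    }

    long batchSizeInBytes() {
        return batchSizeInBytes;
    }
}
```

Because the setting is declared `OperatorDynamic`, it can be changed at runtime through the cluster settings API, but only by operator users on clusters where operator privileges are enabled.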
@@ -218,29 +243,54 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { @Override public void run() { - Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + executeNext(0); + } + + private void executeNext(int itemOffset) { + if (itemOffset >= bulkShardRequest.items().length) { + onCompletion.run(); + return; + } + + var items = bulkShardRequest.items(); + Map> fieldRequestsMap = new HashMap<>(); + long totalInputLength = 0; + int itemIndex = itemOffset; + while (itemIndex < items.length && totalInputLength < batchSizeInBytes) { + var item = items[itemIndex]; + totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap); + itemIndex += 1; + } + int nextItemOffset = itemIndex; Runnable onInferenceCompletion = () -> { try { - for (var inferenceResponse : inferenceResults.asList()) { - var request = bulkShardRequest.items()[inferenceResponse.id]; - try { - applyInferenceResponses(request, inferenceResponse); - } catch (Exception exc) { - request.abort(bulkShardRequest.index(), exc); + for (int i = itemOffset; i < nextItemOffset; i++) { + var result = inferenceResults.get(i); + if (result == null) { + continue; } + var item = items[i]; + try { + applyInferenceResponses(item, result); + } catch (Exception exc) { + item.abort(bulkShardRequest.index(), exc); + } + // we don't need to keep the inference results around + inferenceResults.set(i, null); } } finally { - onCompletion.run(); + executeNext(nextItemOffset); } }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { - for (var entry : inferenceRequests.entrySet()) { - executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + for (var entry : fieldRequestsMap.entrySet()) { + executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); } } } - private void executeShardBulkInferenceAsync( + private void executeChunkedInferenceAsync( final String inferenceId, @Nullable InferenceProvider inferenceProvider, final List requests, @@ -262,11 +312,11 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { unparsedModel.secrets() ) ); - executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); + executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); } else { try (onFinish) { for (FieldInferenceRequest request : requests) { - inferenceResults.get(request.index).failures.add( + inferenceResults.get(request.bulkItemIndex).failures.add( new ResourceNotFoundException( "Inference service [{}] not found for field [{}]", unparsedModel.service(), @@ -297,7 +347,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { request.field ); } - inferenceResults.get(request.index).failures.add(failure); + inferenceResults.get(request.bulkItemIndex).failures.add(failure); } } } @@ -305,18 +355,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - int currentBatchSize = Math.min(requests.size(), batchSize); - final List currentBatch = requests.subList(0, currentBatchSize); - final List nextBatch = requests.subList(currentBatchSize, requests.size()); - final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new 
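The rewritten run loop above replaces "build all inference requests up front, then slice them into fixed-count batches" with a streaming approach: bulk items are scanned from an offset, accumulated until the configured byte budget is reached, and the next window is only scheduled after the current window's inference calls have completed and its results have been applied and released. A simplified, self-contained sketch of that shape (class and callback names are illustrative, not the real ones):

```java
import java.util.ArrayList;
import java.util.List;

// Simplified, hypothetical sketch of the byte-budget batching loop used by the filter:
// scan items from `offset`, accumulate until the byte budget is met, hand the batch to an
// async step, and only schedule the next window once that step completes.
class ByteBudgetBatcher {
    interface AsyncStep {
        void run(List<String> batch, Runnable onDone);
    }

    private final AsyncStep step;
    private final long batchSizeInBytes;

    ByteBudgetBatcher(AsyncStep step, long batchSizeInBytes) {
        this.step = step;
        this.batchSizeInBytes = batchSizeInBytes;
    }

    void processFrom(int offset, String[] items) {
        if (offset >= items.length) {
            return; // all items handled
        }
        long accumulated = 0;
        int next = offset;
        List<String> batch = new ArrayList<>();
        while (next < items.length && accumulated < batchSizeInBytes) {
            accumulated += items[next].length(); // stand-in for the input size in bytes
            batch.add(items[next]);
            next++;
        }
        final int nextOffset = next;
        // Only one window's inputs and results are alive at a time, which is what
        // bounds the memory held while the bulk request is being processed.
        step.run(batch, () -> processFrom(nextOffset, items));
    }
}
```

In the diff, `executeNext` plays the role of `processFrom`, and clearing each item's slot in `inferenceResults` once its responses are applied is what lets the previous window's data be garbage collected before the next one starts.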
ActionListener<>() { @Override public void onResponse(List results) { - try { + try (onFinish) { var requestsIterator = requests.iterator(); for (ChunkedInference result : results) { var request = requestsIterator.next(); - var acc = inferenceResults.get(request.index); + var acc = inferenceResults.get(request.bulkItemIndex); if (result instanceof ChunkedInferenceError error) { acc.addFailure( new InferenceException( @@ -331,7 +378,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { new FieldInferenceResponse( request.field(), request.sourceField(), - request.input(), + useLegacyFormat ? request.input() : null, request.inputOrder(), request.offsetAdjustment(), inferenceProvider.model, @@ -340,17 +387,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { ); } } - } finally { - onFinish(); } } @Override public void onFailure(Exception exc) { - try { + try (onFinish) { for (FieldInferenceRequest request : requests) { addInferenceResponseFailure( - request.index, + request.bulkItemIndex, new InferenceException( "Exception when running inference id [{}] on field [{}]", exc, @@ -359,16 +404,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { ) ); } - } finally { - onFinish(); - } - } - - private void onFinish() { - if (nextBatch.isEmpty()) { - onFinish.close(); - } else { - executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); } } }; @@ -376,6 +411,132 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener); } + /** + * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap} + * for the specified {@code item}. + * + * @param item The bulk request item to process. + * @param itemIndex The position of the item within the original bulk request. + * @param requestsMap A map storing inference requests, where each key is an inference ID, + * and the value is a list of associated {@link FieldInferenceRequest} objects. + * @return The total content length of all newly added requests, or {@code 0} if no requests were added. + */ + private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map> requestsMap) { + boolean isUpdateRequest = false; + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + isUpdateRequest = true; + if (updateRequest.script() != null) { + addInferenceResponseFailure( + itemIndex, + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + return 0; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request + return 0; + } + + final Map docMap = indexRequest.sourceAsMap(); + long inputLength = 0; + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + + if (useLegacyFormat) { + var originalFieldValue = XContentMapValues.extractValue(field, docMap); + if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) { + // Inference has already been computed, or there is no inference required. 
+ continue; + } + } else { + var inferenceMetadataFieldsValue = XContentMapValues.extractValue( + InferenceMetadataFieldsMapper.NAME + "." + field, + docMap, + EXPLICIT_NULL + ); + if (inferenceMetadataFieldsValue != null) { + // Inference has already been computed + continue; + } + } + + int order = 0; + for (var sourceField : entry.getSourceFields()) { + var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL); + if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) { + /** + * It's an update request, and the source field is explicitly set to null, + * so we need to propagate this information to the inference fields metadata + * to overwrite any inference previously computed on the field. + * This ensures that the field is treated as intentionally cleared, + * preventing any unintended carryover of prior inference results. + */ + var slot = ensureResponseAccumulatorSlot(itemIndex); + slot.addOrUpdateResponse( + new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) + ); + continue; + } + if (valueObj == null || valueObj == EXPLICIT_NULL) { + if (isUpdateRequest && useLegacyFormat) { + addInferenceResponseFailure( + itemIndex, + new ElasticsearchStatusException( + "Field [{}] must be specified on an update request to calculate inference for field [{}]", + RestStatus.BAD_REQUEST, + sourceField, + field + ) + ); + break; + } + continue; + } + var slot = ensureResponseAccumulatorSlot(itemIndex); + final List values; + try { + values = SemanticTextUtils.nodeStringValues(field, valueObj); + } catch (Exception exc) { + addInferenceResponseFailure(itemIndex, exc); + break; + } + + if (INFERENCE_API_FEATURE.check(licenseState) == false) { + addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); + break; + } + + List requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + int offsetAdjustment = 0; + for (String v : values) { + inputLength += v.length(); + if (v.isBlank()) { + slot.addOrUpdateResponse( + new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) + ); + } else { + requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); + } + + // When using the inference metadata fields format, all the input values are concatenated so that the + // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment + // to apply to account for this. 
+ offsetAdjustment += v.length() + 1; // Add one for separator char length + } + } + } + return inputLength; + } + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { FieldInferenceResponseAccumulator acc = inferenceResults.get(id); if (acc == null) { @@ -404,7 +565,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - var newDocMap = indexRequest.sourceAsMap(); Map inferenceFieldsMap = new HashMap<>(); for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); @@ -426,28 +586,22 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { } var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>()); - lst.addAll( - SemanticTextField.toSemanticTextFieldChunks( - resp.input, - resp.offsetAdjustment, - resp.chunkedResults, - indexRequest.getContentType(), - useLegacyFormat - ) - ); + var chunks = useLegacyFormat + ? toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType()) + : toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType()); + lst.addAll(chunks); } - List inputs = responses.stream() - .filter(r -> r.sourceField().equals(fieldName)) - .map(r -> r.input) - .collect(Collectors.toList()); + List inputs = useLegacyFormat + ? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList()) + : null; // The model can be null if we are only processing update requests that clear inference results. This is ok because we will // merge in the field's existing model settings on the data node. var result = new SemanticTextField( useLegacyFormat, fieldName, - useLegacyFormat ? inputs : null, + inputs, new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), model != null ? new MinimalServiceSettings(model) : null, @@ -455,149 +609,52 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { ), indexRequest.getContentType() ); + inferenceFieldsMap.put(fieldName, result); + } - if (useLegacyFormat) { - SemanticTextUtils.insertValue(fieldName, newDocMap, result); - } else { - inferenceFieldsMap.put(fieldName, result); + if (useLegacyFormat) { + var newDocMap = indexRequest.sourceAsMap(); + for (var entry : inferenceFieldsMap.entrySet()) { + SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue()); + } + indexRequest.source(newDocMap, indexRequest.getContentType()); + } else { + try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { + appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); + indexRequest.source(builder); } } - if (useLegacyFormat == false) { - newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap); + } + } + + /** + * Appends the original source and the new inference metadata field directly to the provided + * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}. 
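A worked example of the offset adjustment may help. Suppose a semantic-text field has two source values, "red shoes" and "blue hat". Each value is sent to the chunker separately, so chunk offsets come back relative to that value; because the stored representation treats the values as one string joined by a single separator character, the filter shifts each chunk by the accumulated length of the preceding values plus one separator per value. A small illustrative program (the values are made up):

```java
public class OffsetAdjustmentDemo {
    public static void main(String[] args) {
        // Illustrative only: how per-value chunk offsets are shifted into the coordinate
        // space of the concatenated input (one separator char between values).
        String[] values = { "red shoes", "blue hat" };
        int offsetAdjustment = 0;
        for (String v : values) {
            int chunkStart = 0;          // a chunk spanning the whole value, relative to `v`
            int chunkEnd = v.length();
            System.out.printf("chunk [%d,%d) in the concatenated string%n",
                chunkStart + offsetAdjustment, chunkEnd + offsetAdjustment);
            offsetAdjustment += v.length() + 1; // +1 for the separator char
        }
        // Prints: chunk [0,9) for "red shoes", then chunk [10,18) for "blue hat".
    }
}
```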
+ */ + private static void appendSourceAndInferenceMetadata( + XContentBuilder builder, + BytesReference source, + XContentType xContentType, + Map inferenceFieldsMap + ) throws IOException { + builder.startObject(); + + // append the original source + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) { + // skip start object + parser.nextToken(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + builder.copyCurrentStructure(parser); } - indexRequest.source(newDocMap, indexRequest.getContentType()); } - /** - * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. - * If results are already populated for fields in the original index request, the inference request for this specific - * field is skipped, and the existing results remain unchanged. - * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, - * where an error will be thrown if they mismatch or if the content is malformed. - *
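The other half of the memory optimization sits in `applyInferenceResponses`: for the new index format the original source is no longer materialized with `sourceAsMap()` and mutated in place. Instead the source bytes are streamed token by token into a fresh builder and the `InferenceMetadataFieldsMapper.NAME` object is appended at the end. A hedged, standalone sketch of that copy pattern, using only the XContent helpers the change itself relies on (the class and method names here are illustrative):

```java
import java.io.IOException;
import java.util.Map;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;

class SourceRewriteSketch {
    // Streams the original source into `builder` and appends one extra top-level field,
    // without materializing the source as a Map first.
    static void appendWithExtraField(
        XContentBuilder builder,
        BytesReference source,
        XContentType type,
        String extraFieldName,
        Map<String, Object> extraFieldValue
    ) throws IOException {
        builder.startObject();
        try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, type)) {
            parser.nextToken(); // skip the enclosing START_OBJECT
            while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
                builder.copyCurrentStructure(parser); // copies each field name and value as-is
            }
        }
        builder.field(extraFieldName);
        try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, extraFieldValue)) {
            builder.copyCurrentStructure(parser);
        }
        builder.endObject();
    }
}
```

A caller would then hand the builder back to the index request, as the filter does with `indexRequest.source(builder)`.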

- * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? - */ - private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { - Map> fieldRequestsMap = new LinkedHashMap<>(); - for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) { - var item = bulkShardRequest.items()[itemIndex]; - if (item.getPrimaryResponse() != null) { - // item was already aborted/processed by a filter in the chain upstream (e.g. security) - continue; - } - boolean isUpdateRequest = false; - final IndexRequest indexRequest; - if (item.request() instanceof IndexRequest ir) { - indexRequest = ir; - } else if (item.request() instanceof UpdateRequest updateRequest) { - isUpdateRequest = true; - if (updateRequest.script() != null) { - addInferenceResponseFailure( - itemIndex, - new ElasticsearchStatusException( - "Cannot apply update with a script on indices that contain [{}] field(s)", - RestStatus.BAD_REQUEST, - SemanticTextFieldMapper.CONTENT_TYPE - ) - ); - continue; - } - indexRequest = updateRequest.doc(); - } else { - // ignore delete request - continue; - } - - final Map docMap = indexRequest.sourceAsMap(); - for (var entry : fieldInferenceMap.values()) { - String field = entry.getName(); - String inferenceId = entry.getInferenceId(); - - if (useLegacyFormat) { - var originalFieldValue = XContentMapValues.extractValue(field, docMap); - if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) { - // Inference has already been computed, or there is no inference required. - continue; - } - } else { - var inferenceMetadataFieldsValue = XContentMapValues.extractValue( - InferenceMetadataFieldsMapper.NAME + "." + field, - docMap, - EXPLICIT_NULL - ); - if (inferenceMetadataFieldsValue != null) { - // Inference has already been computed - continue; - } - } - - int order = 0; - for (var sourceField : entry.getSourceFields()) { - var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL); - if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) { - /** - * It's an update request, and the source field is explicitly set to null, - * so we need to propagate this information to the inference fields metadata - * to overwrite any inference previously computed on the field. - * This ensures that the field is treated as intentionally cleared, - * preventing any unintended carryover of prior inference results. 
- */ - var slot = ensureResponseAccumulatorSlot(itemIndex); - slot.addOrUpdateResponse( - new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) - ); - continue; - } - if (valueObj == null || valueObj == EXPLICIT_NULL) { - if (isUpdateRequest && useLegacyFormat) { - addInferenceResponseFailure( - itemIndex, - new ElasticsearchStatusException( - "Field [{}] must be specified on an update request to calculate inference for field [{}]", - RestStatus.BAD_REQUEST, - sourceField, - field - ) - ); - break; - } - continue; - } - var slot = ensureResponseAccumulatorSlot(itemIndex); - final List values; - try { - values = SemanticTextUtils.nodeStringValues(field, valueObj); - } catch (Exception exc) { - addInferenceResponseFailure(itemIndex, exc); - break; - } - - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); - break; - } - - List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); - int offsetAdjustment = 0; - for (String v : values) { - if (v.isBlank()) { - slot.addOrUpdateResponse( - new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) - ); - } else { - fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); - } - - // When using the inference metadata fields format, all the input values are concatenated so that the - // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment - // to apply to account for this. - offsetAdjustment += v.length() + 1; // Add one for separator char length - } - } - } - } - return fieldRequestsMap; + // add the inference metadata field + builder.field(InferenceMetadataFieldsMapper.NAME); + try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) { + builder.copyCurrentStructure(parser); } + + builder.endObject(); } static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index ba86a45159b0..7ed17d34ae8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -267,37 +267,38 @@ public record SemanticTextField( /** * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. 
*/ - public static List toSemanticTextFieldChunks( - String input, - int offsetAdjustment, - ChunkedInference results, - XContentType contentType, - boolean useLegacyFormat - ) throws IOException { + public static List toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType) + throws IOException { List chunks = new ArrayList<>(); Iterator it = results.chunksAsByteReference(contentType.xContent()); while (it.hasNext()) { - chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat)); + chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next())); } return chunks; } - public static Chunk toSemanticTextFieldChunk( - String input, - int offsetAdjustment, - ChunkedInference.Chunk chunk, - boolean useLegacyFormat - ) { + /** + * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. + */ + public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) { String text = null; - int startOffset = -1; - int endOffset = -1; - if (useLegacyFormat) { - text = input.substring(chunk.textOffset().start(), chunk.textOffset().end()); - } else { - startOffset = chunk.textOffset().start() + offsetAdjustment; - endOffset = chunk.textOffset().end() + offsetAdjustment; - } - + int startOffset = chunk.textOffset().start() + offsetAdjustment; + int endOffset = chunk.textOffset().end() + offsetAdjustment; return new Chunk(text, startOffset, endOffset, chunk.bytesReference()); } + + public static List toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType) + throws IOException { + List chunks = new ArrayList<>(); + Iterator it = results.chunksAsByteReference(contentType.xContent()); + while (it.hasNext()) { + chunks.add(toSemanticTextFieldChunkLegacy(input, it.next())); + } + return chunks; + } + + public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) { + var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end()); + return new Chunk(text, -1, -1, chunk.bytesReference()); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 61b7e08b6fba..d50daa9f4515 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -28,7 +28,9 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.IndexVersion; @@ -66,12 +68,13 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static 
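The single `toSemanticTextFieldChunk` method with a `useLegacyFormat` flag is split above into two explicit shapes: the legacy converter copies the chunk text out of the input and leaves the offsets at -1, while the new converter keeps no text and stores start/end offsets (shifted by the offset adjustment) into the concatenated input. A hypothetical illustration, with a stand-in record rather than the real `SemanticTextField.Chunk`:

```java
// Hypothetical record standing in for SemanticTextField.Chunk, to show the two shapes.
record Chunk(String text, int startOffset, int endOffset, Object embeddings) {}

class ChunkShapes {
    // New (inference metadata fields) format: no copied text, offsets into the concatenated input.
    static Chunk offsetBased(int start, int end, int offsetAdjustment, Object embeddings) {
        return new Chunk(null, start + offsetAdjustment, end + offsetAdjustment, embeddings);
    }

    // Legacy format: the chunk text is copied out of the original input, offsets are unused (-1).
    static Chunk legacy(String input, int start, int end, Object embeddings) {
        return new Chunk(input.substring(start, end), -1, -1, embeddings);
    }
}
```

Dropping the copied text in the new format is consistent with the rest of the change: the original input no longer needs to be carried through the response objects, so `FieldInferenceResponse.input` is only populated when the legacy format is in use.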
org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; @@ -118,7 +121,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -144,7 +147,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -185,7 +188,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -232,7 +234,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -303,7 +304,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -374,7 +374,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), - randomIntBetween(1, 10), useLegacyFormat, true ); @@ -447,13 +446,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter( - threadPool, - inferenceModelMap, - randomIntBetween(10, 30), - useLegacyFormat, - true - ); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -487,7 +480,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { private static ShardBulkInferenceActionFilter 
createFilter( ThreadPool threadPool, Map modelMap, - int batchSize, boolean useLegacyFormat, boolean isLicenseValidForInference ) { @@ -554,18 +546,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { createClusterService(useLegacyFormat), inferenceServiceRegistry, modelRegistry, - licenseState, - batchSize + licenseState ); } private static ClusterService createClusterService(boolean useLegacyFormat) { IndexMetadata indexMetadata = mock(IndexMetadata.class); - var settings = Settings.builder() + var indexSettings = Settings.builder() .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current()) .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat) .build(); - when(indexMetadata.getSettings()).thenReturn(settings); + when(indexMetadata.getSettings()).thenReturn(indexSettings); ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build()); when(project.index(anyString())).thenReturn(indexMetadata); @@ -576,7 +567,10 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build(); ClusterService clusterService = mock(ClusterService.class); when(clusterService.state()).thenReturn(clusterState); - + long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes()); + Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE))); return clusterService; } @@ -587,7 +581,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase { ) throws IOException { Map docMap = new LinkedHashMap<>(); Map expectedDocMap = new LinkedHashMap<>(); - XContentType requestContentType = randomFrom(XContentType.values()); + // force JSON to avoid double/float conversions + XContentType requestContentType = XContentType.JSON; Map inferenceMetadataFields = new HashMap<>(); for (var entry : fieldInferenceMap.values()) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index af6b398133e7..d092e9bc68e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -41,6 +41,7 @@ import java.util.function.Predicate; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -274,7 +275,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase
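One practical consequence for tests, visible in the `createClusterService` changes above: because the filter's constructor now resolves `INDICES_INFERENCE_BATCH_SIZE` from the node settings and registers a settings update consumer, a mocked `ClusterService` has to stub both `getSettings()` and `getClusterSettings()` with that setting registered, or construction will fail. A minimal sketch of that stubbing, assuming Mockito as in the existing test class (the 512-byte value is an arbitrary example):

```java
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Set;

import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;

class ClusterServiceStubSketch {
    // Minimal stubbing so the filter's constructor can read the batch size setting and
    // register its update consumer against a ClusterSettings instance that knows about it.
    static ClusterService stubClusterService() {
        Settings settings = Settings.builder()
            .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(512))
            .build();
        ClusterService clusterService = mock(ClusterService.class);
        when(clusterService.getSettings()).thenReturn(settings);
        when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
        return clusterService;
    }
}
```

The randomized batch size used by the real test (anywhere from 0 bytes up to 1kb) has the nice side effect of exercising both the "everything fits in one window" and the "one item per window" paths of the new loop.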