diff --git a/docs/changelog/123150.yaml b/docs/changelog/123150.yaml new file mode 100644 index 000000000000..d9c9072f6213 --- /dev/null +++ b/docs/changelog/123150.yaml @@ -0,0 +1,5 @@ +pr: 123150 +summary: Limit the number of chunks for semantic text to prevent high memory usage +area: Machine Learning +type: feature +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java index 3159419ad718..75aee7230be5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java @@ -17,7 +17,7 @@ import java.util.List; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; -public record ChunkedInferenceEmbedding(List chunks) implements ChunkedInference { +public record ChunkedInferenceEmbedding(List chunks) implements ChunkedInference { public static List listOf(List inputs, SparseEmbeddingResults sparseEmbeddingResults) { validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); @@ -27,10 +27,7 @@ public record ChunkedInferenceEmbedding(List c results.add( new ChunkedInferenceEmbedding( List.of( - new SparseEmbeddingResults.Chunk( - sparseEmbeddingResults.embeddings().get(i).tokens(), - new TextOffset(0, inputs.get(i).length()) - ) + new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length())) ) ) ); @@ -41,10 +38,10 @@ public record ChunkedInferenceEmbedding(List c @Override public Iterator chunksAsByteReference(XContent xcontent) throws IOException { - var asChunk = new ArrayList(); - for (var chunk : chunks()) { - asChunk.add(chunk.toChunk(xcontent)); + List chunkedInferenceChunks = new ArrayList<>(); + for (EmbeddingResults.Chunk embeddingResultsChunk : chunks()) { + chunkedInferenceChunks.add(new Chunk(embeddingResultsChunk.offset(), embeddingResultsChunk.embedding().toBytesRef(xcontent))); } - return asChunk.iterator(); + return chunkedInferenceChunks.iterator(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java index 8cd5d78a8ca9..9d889360cfc7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.inference.results; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContent; @@ -19,31 +20,30 @@ import java.util.List; * A call to the inference service may contain multiple input texts, so this results may * contain multiple results. */ -public interface EmbeddingResults> - extends - InferenceServiceResults { - - /** - * A resulting embedding together with the offset into the input text. 
- */ - interface Chunk { - ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException; - - ChunkedInference.TextOffset offset(); - } +public interface EmbeddingResults> extends InferenceServiceResults { /** * A resulting embedding for one of the input texts to the inference service. */ - interface Embedding { + interface Embedding> { /** - * Combines the resulting embedding with the offset into the input text into a chunk. + * Merges the existing embedding and provided embedding into a new embedding. */ - C toChunk(ChunkedInference.TextOffset offset); + E merge(E embedding); + + /** + * Serializes the embedding to bytes. + */ + BytesReference toBytesRef(XContent xContent) throws IOException; } /** * The resulting list of embeddings for the input texts to the inference service. */ List embeddings(); + + /** + * A resulting embedding together with the offset into the input text. + */ + record Chunk(Embedding embedding, ChunkedInference.TextOffset offset) {} } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java index c4001a6325fc..69665dad3d51 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; @@ -27,17 +26,17 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -public record SparseEmbeddingResults(List embeddings) - implements - EmbeddingResults { +public record SparseEmbeddingResults(List embeddings) implements EmbeddingResults { public static final String NAME = "sparse_embedding_results"; public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString(); @@ -124,7 +123,7 @@ public record SparseEmbeddingResults(List embeddings) implements Writeable, ToXContentObject, - EmbeddingResults.Embedding { + EmbeddingResults.Embedding { public static final String EMBEDDING = "embedding"; public static final String IS_TRUNCATED = "is_truncated"; @@ -175,18 +174,35 @@ public record SparseEmbeddingResults(List embeddings) } @Override - public Chunk toChunk(ChunkedInference.TextOffset offset) { - return new Chunk(tokens, offset); - } - } - - public record Chunk(List weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk { - - public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException { - return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens)); + public Embedding merge(Embedding embedding) { + // This code assumes that the tokens are sorted by weight 
in descending order. + // If that's not the case, the resulting merged embedding will be incorrect. + List mergedTokens = new ArrayList<>(); + Set seenTokens = new HashSet<>(); + int i = 0; + int j = 0; + // TODO: maybe truncate tokens here when it's getting too large? + while (i < tokens().size() || j < embedding.tokens().size()) { + WeightedToken token; + if (i == tokens().size()) { + token = embedding.tokens().get(j++); + } else if (j == embedding.tokens().size()) { + token = tokens().get(i++); + } else if (tokens.get(i).weight() > embedding.tokens().get(j).weight()) { + token = tokens().get(i++); + } else { + token = embedding.tokens().get(j++); + } + if (seenTokens.add(token.token())) { + mergedTokens.add(token); + } + } + boolean mergedIsTruncated = isTruncated || embedding.isTruncated(); + return new Embedding(mergedTokens, mergedIsTruncated); } - private static BytesReference toBytesReference(XContent xContent, List tokens) throws IOException { + @Override + public BytesReference toBytesRef(XContent xContent) throws IOException { XContentBuilder b = XContentBuilder.builder(xContent); b.startObject(); for (var weightedToken : tokens) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java index a8f676bf41ce..15a2b8ce0f60 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java @@ -40,9 +40,11 @@ import java.util.Objects; * ] * } */ +// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the +// Embedding.merge method for bits. 
TODO: implement a proper merge method public record TextEmbeddingBitResults(List embeddings) implements - TextEmbeddingResults { + TextEmbeddingResults { public static final String NAME = "text_embedding_service_bit_results"; public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java index fd8f22e535ee..75f2a0268c0d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java @@ -15,7 +15,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; @@ -48,9 +47,7 @@ import java.util.Objects; * ] * } */ -public record TextEmbeddingByteResults(List embeddings) - implements - TextEmbeddingResults { +public record TextEmbeddingByteResults(List embeddings) implements TextEmbeddingResults { public static final String NAME = "text_embedding_service_byte_results"; public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes"; @@ -118,9 +115,20 @@ public record TextEmbeddingByteResults(List embeddings) return Objects.hash(embeddings); } - public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding { + // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging + // embeddings should happen inbetween serializations. + public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings) + implements + Writeable, + ToXContentObject, + EmbeddingResults.Embedding { + public static final String EMBEDDING = "embedding"; + public Embedding(byte[] values) { + this(values, null, 1); + } + public Embedding(StreamInput in) throws IOException { this(in.readByteArray()); } @@ -187,25 +195,26 @@ public record TextEmbeddingByteResults(List embeddings) } @Override - public Chunk toChunk(ChunkedInference.TextOffset offset) { - return new Chunk(values, offset); - } - } - - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. - */ - public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk { - - public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException { - return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding)); + public Embedding merge(Embedding embedding) { + byte[] newValues = new byte[values.length]; + int[] newSumMergedValues = new int[values.length]; + int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings; + for (int i = 0; i < values.length; i++) { + newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i]) + + (embedding.numberOfMergedEmbeddings == 1 ? 
embedding.values[i] : embedding.sumMergedValues[i]); + // Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the + // closest byte instead of truncating. + newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings); + } + return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings); } - private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException { + @Override + public BytesReference toBytesRef(XContent xContent) throws IOException { XContentBuilder builder = XContentBuilder.builder(xContent); builder.startArray(); - for (byte v : value) { - builder.value(v); + for (byte value : values) { + builder.value(value); } builder.endArray(); return BytesReference.bytes(builder); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java index 8dfdf57f9d1b..b13b82b7dd39 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; @@ -53,9 +52,7 @@ import java.util.stream.Collectors; * ] * } */ -public record TextEmbeddingFloatResults(List embeddings) - implements - TextEmbeddingResults { +public record TextEmbeddingFloatResults(List embeddings) implements TextEmbeddingResults { public static final String NAME = "text_embedding_service_results"; public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); @@ -155,9 +152,19 @@ public record TextEmbeddingFloatResults(List embeddings) return Objects.hash(embeddings); } - public record Embedding(float[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding { + // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging + // embeddings should happen inbetween serializations. 
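Both the byte and the float embeddings in this PR carry a numberOfMergedEmbeddings count so that repeated pairwise merges compute an exact mean rather than an average of averages; the byte variant above additionally keeps an integer running sum to avoid accumulating rounding error. A minimal standalone sketch of that running-mean idea (hypothetical names, plain floats, not the production record that follows):

```java
// Illustrative sketch only: the running weighted-average idea behind the
// Embedding.merge implementations in this PR. Names are hypothetical.
public final class RunningMeanMergeSketch {

    // A vector plus the number of original embeddings it already represents.
    record Averaged(float[] values, int count) {

        // Merge two averaged vectors into one average over all underlying embeddings.
        Averaged merge(Averaged other) {
            float[] merged = new float[values().length];
            for (int i = 0; i < values().length; i++) {
                // Weighted mean: each side contributes proportionally to how many
                // embeddings it already averages, so repeated merges stay exact.
                merged[i] = (count() * values()[i] + other.count() * other.values()[i]) / (count() + other.count());
            }
            return new Averaged(merged, count() + other.count());
        }
    }

    public static void main(String[] args) {
        Averaged a = new Averaged(new float[] { 1f, 4f }, 1);
        Averaged b = new Averaged(new float[] { 3f, 0f }, 1);
        Averaged c = new Averaged(new float[] { 5f, 2f }, 1);
        Averaged all = a.merge(b).merge(c);
        // Prints [3.0, 2.0]: the mean of the three vectors, regardless of merge order.
        System.out.println(java.util.Arrays.toString(all.values()));
    }
}
```

Because the merge count is never serialized, the same caveat as in the patch comments applies: all merging has to happen before the embedding is written out.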
+ public record Embedding(float[] values, int numberOfMergedEmbeddings) + implements + Writeable, + ToXContentObject, + EmbeddingResults.Embedding { public static final String EMBEDDING = "embedding"; + public Embedding(float[] values) { + this(values, 1); + } + public Embedding(StreamInput in) throws IOException { this(in.readFloatArray()); } @@ -221,25 +228,21 @@ public record TextEmbeddingFloatResults(List embeddings) } @Override - public Chunk toChunk(ChunkedInference.TextOffset offset) { - return new Chunk(values, offset); - } - } - - public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk { - - public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException { - return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding)); + public Embedding merge(Embedding embedding) { + float[] mergedValues = new float[values.length]; + for (int i = 0; i < values.length; i++) { + mergedValues[i] = (numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i]) + / (numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings); + } + return new Embedding(mergedValues, numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings); } - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException { + @Override + public BytesReference toBytesRef(XContent xContent) throws IOException { XContentBuilder b = XContentBuilder.builder(xContent); b.startArray(); - for (float v : value) { - b.value(v); + for (float value : values) { + b.value(value); } b.endArray(); return BytesReference.bytes(b); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 4caeea4930fd..ea4e45ec6740 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -7,9 +7,7 @@ package org.elasticsearch.xpack.core.inference.results; -public interface TextEmbeddingResults> - extends - EmbeddingResults { +public interface TextEmbeddingResults> extends EmbeddingResults { /** * Returns the first text embedding entry in the result list's array size. diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResultsTests.java similarity index 97% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResultsTests.java index 1b9b2db660bf..909478c2e537 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResultsTests.java @@ -5,12 +5,11 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyMlTextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyMlTextEmbeddingResultsTests.java similarity index 97% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyMlTextEmbeddingResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyMlTextEmbeddingResultsTests.java index f7ed3f34d364..6251881e41b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyMlTextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/LegacyMlTextEmbeddingResultsTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; @@ -14,7 +14,6 @@ import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResultsTests.java similarity index 82% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResultsTests.java index dcdbc13f097b..2c8b8487c67e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResultsTests.java @@ -5,12 +5,11 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -20,6 +19,7 @@ import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { @@ -161,6 +161,44 @@ public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase ); } + public void testEmbeddingMerge() { + SparseEmbeddingResults.Embedding embedding1 = new SparseEmbeddingResults.Embedding( + List.of( + new WeightedToken("this", 1.0f), + new WeightedToken("is", 0.8f), + new WeightedToken("the", 0.6f), + new WeightedToken("first", 0.4f), + new WeightedToken("embedding", 0.2f) + ), + true + ); + SparseEmbeddingResults.Embedding embedding2 = new SparseEmbeddingResults.Embedding( + List.of( + new WeightedToken("this", 0.95f), + new WeightedToken("is", 0.85f), + new WeightedToken("another", 0.65f), + new WeightedToken("embedding", 0.15f) + ), + false + ); + assertThat( + embedding1.merge(embedding2), + equalTo( + new SparseEmbeddingResults.Embedding( + List.of( + new WeightedToken("this", 1.0f), + new WeightedToken("is", 0.85f), + new WeightedToken("another", 0.65f), + new WeightedToken("the", 0.6f), + new WeightedToken("first", 0.4f), + new WeightedToken("embedding", 0.2f) + ), + true + ) + ) + ); + } + public record EmbeddingExpectation(Map tokens, boolean isTruncated) {} public static Map buildExpectationSparseEmbeddings(List embeddings) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingBitResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java similarity index 96% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingBitResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java index fb3203f633ff..61b49075702a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingBitResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResultsTests.java @@ -5,13 +5,11 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java similarity index 85% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java index 945eadd67d1f..60f45399cfb3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResultsTests.java @@ -5,12 +5,11 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import java.io.IOException; @@ -18,6 +17,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { @@ -115,6 +115,16 @@ public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCa assertThat(firstEmbeddingSize, is(2)); } + public void testEmbeddingMerge() { + TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 }); + TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 }); + TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 }); + TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2); + assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 }))); + mergedEmbedding = mergedEmbedding.merge(embedding3); + assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 }))); + } + @Override protected Writeable.Reader instanceReader() { return TextEmbeddingByteResults::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java similarity index 83% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java rename to 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java index c56defb69382..8cdd98bcdebc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResultsTests.java @@ -5,13 +5,11 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import java.io.IOException; @@ -19,9 +17,10 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase { +public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase { public static TextEmbeddingFloatResults createRandomResults() { int embeddings = randomIntBetween(1, 10); List embeddingResults = new ArrayList<>(embeddings); @@ -116,6 +115,16 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase instanceReader() { return TextEmbeddingFloatResults::new; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index b744f540e769..2389ee45911a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -35,6 +35,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; @@ -181,8 +182,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten results.add( new ChunkedInferenceEmbedding( List.of( - new TextEmbeddingFloatResults.Chunk( - nonChunkedResults.embeddings().get(i).values(), + new EmbeddingResults.Chunk( + nonChunkedResults.embeddings().get(i), new ChunkedInference.TextOffset(0, input.get(i).length()) ) ) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 4e10ce45efea..b860bb85ebd0 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -33,6 +33,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -172,7 +173,12 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte } results.add( new ChunkedInferenceEmbedding( - List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length()))) + List.of( + new EmbeddingResults.Chunk( + new SparseEmbeddingResults.Embedding(tokens, false), + new ChunkedInference.TextOffset(0, input.get(i).length()) + ) + ) ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 0d018f30a8a6..13bb406ed481 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -36,7 +36,7 @@ import java.util.stream.Collectors; * processing and map the results back to the original element * in the input list. */ -public class EmbeddingRequestChunker { +public class EmbeddingRequestChunker> { // Visible for testing record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { @@ -56,13 +56,19 @@ public class EmbeddingRequestChunker { private static final int DEFAULT_WORDS_PER_CHUNK = 250; private static final int DEFAULT_CHUNK_OVERLAP = 100; - private final List inputs; - private final List> requests; + // The maximum number of chunks that is stored for any input text. + // If the configured chunker chunks the text into more chunks, each + // chunk is sent to the inference service separately, but the results + // are merged so that only this maximum number of chunks is stored. 
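The chunker constructor further down maps each original chunk index onto one of at most MAX_CHUNKS result slots using integer arithmetic (chunkIndex * MAX_CHUNKS / chunks.size()), which preserves chunk order and spreads the original chunks roughly evenly across the slots before their embeddings are merged. A standalone sketch of just that bucketing arithmetic, using a hypothetical small limit instead of the PR's 512:

```java
// Illustrative sketch of the index scaling used in the constructor below;
// not the production class, just the arithmetic in isolation.
public final class ChunkBucketingSketch {

    // Map an original chunk index into one of at most maxChunks result slots.
    static int targetSlot(int chunkIndex, int totalChunks, int maxChunks) {
        return totalChunks <= maxChunks ? chunkIndex : chunkIndex * maxChunks / totalChunks;
    }

    public static void main(String[] args) {
        int totalChunks = 10;
        int maxChunks = 4; // hypothetical small limit for readability
        for (int i = 0; i < totalChunks; i++) {
            // Prints slots 0,0,0,1,1,2,2,2,3,3: order is preserved and each slot
            // receives either floor(10/4) or ceil(10/4) of the original chunks.
            System.out.println("chunk " + i + " -> slot " + targetSlot(i, totalChunks, maxChunks));
        }
    }
}
```

The per-slot text offsets are then tracked by recording the start of the first chunk that lands in a slot and extending the end as later chunks land in the same slot, which is what the resultOffsetStarts and resultOffsetEnds lists below implement.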
+ private static final int MAX_CHUNKS = 512; + private final List batchRequests; private final AtomicInteger resultCount = new AtomicInteger(); - private final List>> results; - private final AtomicArray errors; + private final List> resultOffsetStarts; + private final List> resultOffsetEnds; + private final List> resultEmbeddings; + private final AtomicArray resultsErrors; private ActionListener> finalListener; public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { @@ -74,31 +80,41 @@ public class EmbeddingRequestChunker { } public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, ChunkingSettings chunkingSettings) { - this.inputs = inputs; - this.results = new ArrayList<>(inputs.size()); - this.errors = new AtomicArray<>(inputs.size()); + this.resultEmbeddings = new ArrayList<>(inputs.size()); + this.resultOffsetStarts = new ArrayList<>(inputs.size()); + this.resultOffsetEnds = new ArrayList<>(inputs.size()); + this.resultsErrors = new AtomicArray<>(inputs.size()); if (chunkingSettings == null) { chunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP); } Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); - this.requests = new ArrayList<>(inputs.size()); - + List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { List chunks = chunker.chunk(inputs.get(inputIndex), chunkingSettings); - List requestForInput = new ArrayList<>(chunks.size()); + int resultCount = Math.min(chunks.size(), MAX_CHUNKS); + resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); + resultOffsetStarts.add(new ArrayList<>(resultCount)); + resultOffsetEnds.add(new ArrayList<>(resultCount)); + for (int chunkIndex = 0; chunkIndex < chunks.size(); chunkIndex++) { - requestForInput.add(new Request(inputIndex, chunkIndex, chunks.get(chunkIndex), inputs)); + // If the number of chunks is larger than the maximum allowed value, + // scale the indices to [0, MAX) with similar number of original + // chunks in the final chunks. + int targetChunkIndex = chunks.size() <= MAX_CHUNKS ? 
chunkIndex : chunkIndex * MAX_CHUNKS / chunks.size(); + if (resultOffsetStarts.getLast().size() <= targetChunkIndex) { + resultOffsetStarts.getLast().add(chunks.get(chunkIndex).start()); + resultOffsetEnds.getLast().add(chunks.get(chunkIndex).end()); + } else { + resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end()); + } + allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs)); } - requests.add(requestForInput); - // size the results array with the expected number of request/responses - results.add(new AtomicReferenceArray<>(chunks.size())); } AtomicInteger counter = new AtomicInteger(); - this.batchRequests = requests.stream() - .flatMap(List::stream) + this.batchRequests = allRequests.stream() .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch)) .values() .stream() @@ -134,20 +150,26 @@ public class EmbeddingRequestChunker { @Override public void onResponse(InferenceServiceResults inferenceServiceResults) { - if (inferenceServiceResults instanceof EmbeddingResults embeddingResults) { - if (embeddingResults.embeddings().size() != request.requests.size()) { - onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size())); - return; - } - for (int i = 0; i < embeddingResults.embeddings().size(); i++) { - results.get(request.requests().get(i).inputIndex()) - .set(request.requests().get(i).chunkIndex(), embeddingResults.embeddings().get(i)); - } - if (resultCount.incrementAndGet() == batchRequests.size()) { - sendFinalResponse(); - } - } else { + if (inferenceServiceResults instanceof EmbeddingResults == false) { onFailure(unexpectedResultTypeException(inferenceServiceResults.getWriteableName())); + return; + } + @SuppressWarnings("unchecked") + EmbeddingResults embeddingResults = (EmbeddingResults) inferenceServiceResults; + if (embeddingResults.embeddings().size() != request.requests.size()) { + onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size())); + return; + } + for (int i = 0; i < embeddingResults.embeddings().size(); i++) { + E newEmbedding = embeddingResults.embeddings().get(i); + resultEmbeddings.get(request.requests().get(i).inputIndex()) + .updateAndGet( + request.requests().get(i).chunkIndex(), + oldEmbedding -> oldEmbedding == null ? 
newEmbedding : oldEmbedding.merge(newEmbedding) + ); + } + if (resultCount.incrementAndGet() == batchRequests.size()) { + sendFinalResponse(); } } @@ -171,7 +193,7 @@ public class EmbeddingRequestChunker { @Override public void onFailure(Exception e) { for (Request request : request.requests) { - errors.set(request.inputIndex(), e); + resultsErrors.set(request.inputIndex(), e); } if (resultCount.incrementAndGet() == batchRequests.size()) { sendFinalResponse(); @@ -180,10 +202,10 @@ public class EmbeddingRequestChunker { } private void sendFinalResponse() { - var response = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - if (errors.get(i) != null) { - response.add(new ChunkedInferenceError(errors.get(i))); + var response = new ArrayList(resultEmbeddings.size()); + for (int i = 0; i < resultEmbeddings.size(); i++) { + if (resultsErrors.get(i) != null) { + response.add(new ChunkedInferenceError(resultsErrors.get(i))); } else { response.add(mergeResultsWithInputs(i)); } @@ -191,14 +213,15 @@ public class EmbeddingRequestChunker { finalListener.onResponse(response); } - private ChunkedInference mergeResultsWithInputs(int index) { + private ChunkedInference mergeResultsWithInputs(int inputIndex) { + List startOffsets = resultOffsetStarts.get(inputIndex); + List endOffsets = resultOffsetEnds.get(inputIndex); + AtomicReferenceArray embeddings = resultEmbeddings.get(inputIndex); + List chunks = new ArrayList<>(); - List request = requests.get(index); - AtomicReferenceArray> result = results.get(index); - for (int i = 0; i < request.size(); i++) { - EmbeddingResults.Chunk chunk = result.get(i) - .toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end())); - chunks.add(chunk); + for (int i = 0; i < embeddings.length(); i++) { + ChunkedInference.TextOffset offset = new ChunkedInference.TextOffset(startOffsets.get(i), endOffsets.get(i)); + chunks.add(new EmbeddingResults.Chunk(embeddings.get(i), offset)); } return new ChunkedInferenceEmbedding(chunks); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 7330d45b6f16..0675d75fdc5f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -741,7 +741,7 @@ public final class ServiceUtils { InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener.delegateFailureAndWrap((delegate, r) -> { - if (r instanceof TextEmbeddingResults embeddingResults) { + if (r instanceof TextEmbeddingResults embeddingResults) { try { delegate.onResponse(embeddingResults.getFirstEmbeddingSize()); } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index dd2b29ec3efe..c413d3b98f59 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -305,7 +305,7 @@ public class AlibabaCloudSearchService extends 
SenderService { AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model; var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, alibabaCloudSearchModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index b9361a2e6623..1c7cac88e776 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -129,7 +129,7 @@ public class AmazonBedrockService extends SenderService { if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), maxBatchSize, baseAmazonBedrockModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index c82d0753edee..cd1f1044d068 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -121,7 +121,7 @@ public class AzureAiStudioService extends SenderService { if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) { var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, baseAzureAiStudioModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 0f3e84e7c13e..874606456b80 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -281,7 +281,7 @@ public class AzureOpenAiService extends SenderService { AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model; var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, azureOpenAiModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c951a008df0f..e64c62befa9c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -284,7 +284,7 @@ public class CohereService extends SenderService { CohereModel cohereModel = (CohereModel) model; var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, cohereModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 1beb476832b2..58afc1b780e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -725,7 +725,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi if (model instanceof ElasticsearchInternalModel esModel) { - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( input, EMBEDDING_MAX_BATCH_SIZE, esModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index d9a632677a9f..c5a66ab59d14 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -325,7 +325,7 @@ public class GoogleAiStudioService extends SenderService { ) { GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, googleAiStudioModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 29f7cdee7570..2696c1be4568 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -228,7 +228,7 @@ public class GoogleVertexAiService extends SenderService { GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new 
EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, googleVertexAiModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index ce1a31c90ed7..257631e016ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -127,7 +127,7 @@ public class HuggingFaceService extends HuggingFaceBaseService { var huggingFaceModel = (HuggingFaceModel) model; var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, huggingFaceModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 8009fae673a3..8f48c863d621 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -27,6 +27,7 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; @@ -119,8 +120,8 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { results.add( new ChunkedInferenceEmbedding( List.of( - new TextEmbeddingFloatResults.Chunk( - textEmbeddingResults.embeddings().get(i).values(), + new EmbeddingResults.Chunk( + textEmbeddingResults.embeddings().get(i), new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length()) ) ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 599187048968..ac627dd34f8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -307,7 +307,7 @@ public class IbmWatsonxService extends SenderService { ) { IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; - var batchedRequests = new EmbeddingRequestChunker( + var batchedRequests = new EmbeddingRequestChunker<>( input.getInputs(), EMBEDDING_MAX_BATCH_SIZE, model.getConfigurations().getChunkingSettings() diff 
--git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 6e3c830ae764..50768df3f438 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -266,7 +266,7 @@ public class JinaAIService extends SenderService { JinaAIModel jinaaiModel = (JinaAIModel) model; var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, jinaaiModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 1f50173951db..b88c8cb16e50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -110,7 +110,7 @@ public class MistralService extends SenderService { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), MistralConstants.MAX_BATCH_SIZE, mistralEmbeddingsModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 1b29c0c97e2b..de70092696ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -335,7 +335,7 @@ public class OpenAiService extends SenderService { OpenAiModel openAiModel = (OpenAiModel) model; var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, openAiModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index 766d7436d329..c82d6c00c361 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -33,7 +33,7 @@ public class TextEmbeddingModelValidator implements ModelValidator { } private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { - if (results instanceof 
TextEmbeddingResults embeddingResults) { + if (results instanceof TextEmbeddingResults embeddingResults) { var serviceSettings = model.getServiceSettings(); var dimensions = serviceSettings.dimensions(); int embeddingSize = getEmbeddingSize(embeddingResults); @@ -67,7 +67,7 @@ public class TextEmbeddingModelValidator implements ModelValidator { } } - private int getEmbeddingSize(TextEmbeddingResults embeddingResults) { + private int getEmbeddingSize(TextEmbeddingResults embeddingResults) { int embeddingSize; try { embeddingSize = embeddingResults.getFirstEmbeddingSize(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 16659f075c56..acec9b638bd6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -288,7 +288,7 @@ public class VoyageAIService extends SenderService { VoyageAIModel voyageaiModel = (VoyageAIModel) model; var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); - List batchedRequests = new EmbeddingRequestChunker( + List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), getBatchSize(voyageaiModel), voyageaiModel.getConfigurations().getChunkingSettings() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index cd14d9e54507..c3f3482abb5b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -11,12 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.LegacyMlTextEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; -import org.elasticsearch.xpack.inference.results.LegacyMlTextEmbeddingResultsTests; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests; import java.io.IOException; import java.util.ArrayList; @@ -43,7 +43,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTe @Override protected InferenceAction.Response createTestInstance() { var result = switch (randomIntBetween(0, 2)) { - case 0 -> TextEmbeddingResultsTests.createRandomResults(); + case 0 -> TextEmbeddingFloatResultsTests.createRandomResults(); case 1 -> LegacyMlTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults(); default -> 
SparseEmbeddingResultsTests.createRandomResults(); }; @@ -87,7 +87,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTe } public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException { - var embeddingResults = TextEmbeddingResultsTests.createRandomResults(); + var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults(); var instance = new InferenceAction.Response(embeddingResults); var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0); assertOnBWCObject(copy, instance, V_8_12_0); @@ -103,7 +103,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTe // Technically we should never see a text embedding result in the transport version of this test because support // for it wasn't added until openai public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException { - var embeddingResults = TextEmbeddingResultsTests.createRandomResults(); + var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults(); var instance = new InferenceAction.Response(embeddingResults); var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0); assertOnBWCObject(copy, instance, V_8_12_0); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index aa33cd0000b4..39d729e752c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -33,19 +34,19 @@ import static org.hamcrest.Matchers.startsWith; public class EmbeddingRequestChunkerTests extends ESTestCase { public void testEmptyInput_WordChunker() { - var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, empty()); } public void testEmptyInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker(List.of(), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners( + var batches = new EmbeddingRequestChunker<>(List.of(), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners( testListener() ); assertThat(batches, empty()); } public void testWhitespaceInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1)) + var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); @@ -53,30 +54,29 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { } public void testBlankInput_WordChunker() { - var batches = new 
EmbeddingRequestChunker(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); } public void testBlankInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners( - testListener() - ); + var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)) + .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); } public void testInputThatDoesNotChunk_WordChunker() { - var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { - var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1)) + var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); @@ -85,14 +85,14 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { public void testShortInputsAreSingleBatch() { String input = "one chunk"; - var batches = new EmbeddingRequestChunker(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), contains(input)); } public void testMultipleShortInputsAreSingleBatch() { List inputs = List.of("1st small", "2nd small", "3rd small"); - var batches = new EmbeddingRequestChunker(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch(); assertEquals(batch.inputs(), inputs); @@ -113,7 +113,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { inputs.add("input " + i); } - var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(4)); assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch)); @@ -148,7 +148,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { inputs.add("input " + i); } - var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) 
+ var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(4)); assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); @@ -190,7 +190,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); - var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(2)); @@ -234,6 +234,260 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(request.chunkText(), equalTo("3rd small")); } + public void testVeryLongInput_Sparse() { + int batchSize = 5; + int chunkSize = 20; + int numberOfWordsInPassage = (chunkSize * 10000); + + var passageBuilder = new StringBuilder(); + for (int i = 0; i < numberOfWordsInPassage; i++) { + passageBuilder.append("word").append(i).append(" "); // chunk on whitespace + } + + List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + + var finalListener = testListener(); + List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) + .batchRequestsWithListeners(finalListener); + + // The very long passage is split into 10000 chunks for inference, so + // there are 10002 inference requests, resulting in 2001 batches. + assertThat(batches, hasSize(2001)); + for (int i = 0; i < 2000; i++) { + assertThat(batches.get(i).batch().inputs(), hasSize(5)); + } + assertThat(batches.get(2000).batch().inputs(), hasSize(2)); + + // Produce inference results for each request, with just the token + // "word" and increasing weights. + float weight = 0f; + for (var batch : batches) { + var embeddings = new ArrayList(); + for (int i = 0; i < batch.batch().requests().size(); i++) { + weight += 1 / 16384f; + embeddings.add(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("word", weight)), false)); + } + batch.listener().onResponse(new SparseEmbeddingResults(embeddings)); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(3)); + + // The first input has the token with weight 1/16384f. + ChunkedInference inference = finalListener.results.get(0); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(1)); + assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f))); + + // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for + // inference. They get the embeddings with token "word" and weights 2/16384 ... 10001/16384. + // Next, they are merged into 512 larger chunks, each of which consists of 19 or 20 smaller chunks + // and therefore 380 or 400 words. For each merged chunk, the maximum weight per token is kept.
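A quick sanity check on the arithmetic in the comment above (an illustrative sketch only, not part of the test; the variable names are made up): merging 10000 small chunks down to the 512 chunks asserted below gives 272 merged chunks of 20 small chunks and 240 of 19, and the sparse merge keeps the largest weight seen for each token, as the assertions below confirm.

    // Illustration only: how 10000 small chunks map onto 512 merged chunks.
    int smallChunks = 10_000;
    int mergedChunks = 512;
    int minPerMerged = smallChunks / mergedChunks;       // 19 small chunks per merged chunk, at least
    int largerMerged = smallChunks % mergedChunks;       // 272 merged chunks hold one extra, i.e. 20
    // 272 * 20 + 240 * 19 == 10000
    // For a token shared by two small chunks, a max-weight merge keeps the larger value:
    float mergedWeight = Math.max(20 / 16384f, 21 / 16384f); // 21/16384, matching the first merged chunk below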
+ inference = finalListener.results.get(1); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(512)); + + // The first merged chunk consists of 20 small chunks (so 400 words) and the max + // weight is the weight of the 20th small chunk (so 21/16384). + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f))); + + // The last merged chunk consists of 19 small chunks (so 380 words) and the max + // weight is the weight of the 10000th small chunk (so 10001/16384). + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f))); + + // The last input has the token with weight 10002/16384. + inference = finalListener.results.get(2); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(1)); + assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f))); + } + + public void testVeryLongInput_Float() { + int batchSize = 5; + int chunkSize = 20; + int numberOfWordsInPassage = (chunkSize * 10000); + + var passageBuilder = new StringBuilder(); + for (int i = 0; i < numberOfWordsInPassage; i++) { + passageBuilder.append("word").append(i).append(" "); // chunk on whitespace + } + + List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + + var finalListener = testListener(); + List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) + .batchRequestsWithListeners(finalListener); + + // The very long passage is split into 10000 chunks for inference, so + // there are 10002 inference requests, resulting in 2001 batches. + assertThat(batches, hasSize(2001)); + for (int i = 0; i < 2000; i++) { + assertThat(batches.get(i).batch().inputs(), hasSize(5)); + } + assertThat(batches.get(2000).batch().inputs(), hasSize(2)); + + // Produce inference results for each request, with increasing weights. 
+ float weight = 0f; + for (var batch : batches) { + var embeddings = new ArrayList(); + for (int i = 0; i < batch.batch().requests().size(); i++) { + weight += 1 / 16384f; + embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight })); + } + batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings)); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(3)); + + // The first input has the embedding with weight 1/16384. + ChunkedInference inference = finalListener.results.get(0); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(1)); + assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); + + // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for + // inference. They get the embeddings with weights 2/16384 ... 10001/16384. + // Next, they are merged into 512 larger chunks, each of which consists of 19 or 20 smaller chunks + // and therefore 380 or 400 words. For each merged chunk, the average weight is computed. + inference = finalListener.results.get(1); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(512)); + + // The first merged chunk consists of 20 small chunks (so 400 words) and the weight + // is the average of the weights 2/16384 ... 21/16384. + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); + + // The last merged chunk consists of 19 small chunks (so 380 words) and the weight + // is the average of the weights 9983/16384 ... 10001/16384. + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); + + // The last input has the embedding with weight 10002/16384.
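The expected float values asserted above follow from the mean of consecutive weights: the average of a/16384, (a+1)/16384, ..., b/16384 is (a + b) / (2 * 16384). A minimal sketch of that identity for the first merged chunk (illustrative only, not part of the test):

    // Average of the 20 consecutive weights 2/16384 ... 21/16384.
    float sum = 0f;
    for (int w = 2; w <= 21; w++) {
        sum += w / 16384f;
    }
    float mean = sum / 20; // mathematically equal to (2 + 21) / (2 * 16384f)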
+ inference = finalListener.results.get(2); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(1)); + assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); + } + + public void testVeryLongInput_Byte() { + int batchSize = 5; + int chunkSize = 20; + int numberOfWordsInPassage = (chunkSize * 10000); + + var passageBuilder = new StringBuilder(); + for (int i = 0; i < numberOfWordsInPassage; i++) { + passageBuilder.append("word").append(i).append(" "); // chunk on whitespace + } + + List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + + var finalListener = testListener(); + List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) + .batchRequestsWithListeners(finalListener); + + // The very long passage is split into 10000 chunks for inference, so + // there are 10002 inference requests, resulting in 2001 batches. + assertThat(batches, hasSize(2001)); + for (int i = 0; i < 2000; i++) { + assertThat(batches.get(i).batch().inputs(), hasSize(5)); + } + assertThat(batches.get(2000).batch().inputs(), hasSize(2)); + + // Produce inference results for each request, with increasing weights. + byte weight = 0; + for (var batch : batches) { + var embeddings = new ArrayList(); + for (int i = 0; i < batch.batch().requests().size(); i++) { + weight += 1; + embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight })); + } + batch.listener().onResponse(new TextEmbeddingByteResults(embeddings)); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(3)); + + // The first input has the embedding with weight 1. + ChunkedInference inference = finalListener.results.get(0); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(1)); + assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.values(), equalTo(new byte[] { 1 })); + + // The very long passage "word0 word1 ... word199999" is split into 10000 chunks for + // inference. They get the embeddings with weights 2 ... 10001, stored as overflowing bytes. + // Next, they are merged into 512 larger chunks, each of which consists of 19 or 20 smaller chunks + // and therefore 380 or 400 words. For each merged chunk, the average weight is computed. + inference = finalListener.results.get(1); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(512)); + + // The first merged chunk consists of 20 small chunks (so 400 words) and the weight + // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
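The byte expectations that follow come from signed-byte wraparound plus the same averaging described in the comments above; a small illustration of the arithmetic (a sketch only, not part of the test):

    // The weight of small chunk n is the low byte of n, interpreted as a signed byte:
    byte w9983 = (byte) 9983;   // -1, because 9983 % 256 == 255
    byte w10001 = (byte) 10001; // 17, because 10001 % 256 == 17
    // So the last merged chunk averages -1, 0, ..., 17, giving 8,
    // while the first merged chunk averages 2 ... 21, giving 11.5, rounded to 12 in the assertion below.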
+ assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.values(), equalTo(new byte[] { 12 })); + + // The last merged chunk consists of 19 small chunks (so 380 words) and the weight + // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so + // the average of -1, 0, 1, ... , 17, so 8. + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); + assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); + assertThat(embedding.values(), equalTo(new byte[] { 8 })); + + // The last input has the token with weight 10002 % 256 = 18 + inference = finalListener.results.get(2); + assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); + chunkedEmbedding = (ChunkedInferenceEmbedding) inference; + assertThat(chunkedEmbedding.chunks(), hasSize(1)); + assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); + assertThat(embedding.values(), equalTo(new byte[] { 18 })); + } + public void testMergingListener_Float() { int batchSize = 5; int chunkSize = 20; @@ -246,10 +500,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("a", passageBuilder.toString(), "bb", "ccc"); + List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); var finalListener = testListener(); - var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); + var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); assertThat(batches, hasSize(2)); // 4 inputs in 2 batches @@ -275,7 +529,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedFloatResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(0), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -283,26 +537,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; 
assertThat(chunkedFloatResult.chunks(), hasSize(6)); - assertThat(chunkedFloatResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309))); - assertThat(chunkedFloatResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629))); - assertThat(chunkedFloatResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949))); - assertThat(chunkedFloatResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269))); - assertThat(chunkedFloatResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589))); - assertThat(chunkedFloatResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675))); + assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedFloatResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(2), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedFloatResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(3), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -318,10 +572,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("a", passageBuilder.toString(), "bb", "ccc"); + List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); var finalListener = testListener(); - var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); + var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); assertThat(batches, hasSize(2)); // 4 inputs in 2 batches @@ -347,7 +601,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(0), 
chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -355,26 +609,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309))); - assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629))); - assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949))); - assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269))); - assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589))); - assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675))); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -390,10 +644,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("a", passageBuilder.toString(), "bb", "ccc"); + List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); var finalListener = testListener(); - var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); + var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); assertThat(batches, hasSize(2)); // 4 inputs in 2 batches @@ -419,7 +673,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, 
instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -427,26 +681,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309))); - assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629))); - assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949))); - assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269))); - assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589))); - assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675))); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -462,10 +716,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("a", "bb", "ccc", passageBuilder.toString()); + List inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString()); var finalListener = testListener(); - var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, 
overlap).batchRequestsWithListeners(finalListener); + var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); assertThat(batches, hasSize(3)); // 4 inputs in 3 batches @@ -498,21 +752,21 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedSparseResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(0), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); } { var chunkedResult = finalListener.results.get(1); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedSparseResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(1), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedSparseResult.chunks().get(0).offset()); + assertThat(getMatchedText(inputs.get(2), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); } { // this is the large input split in multiple chunks @@ -520,9 +774,9 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(chunkedSparseResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 149))); - assertThat(chunkedSparseResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(149, 309))); - assertThat(chunkedSparseResult.chunks().get(8).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1350))); + assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 ")); + assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 ")); } } @@ -545,7 +799,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { } }; - var batches = new EmbeddingRequestChunker(inputs, 10, 100, 0).batchRequestsWithListeners(listener); + var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener); assertThat(batches, hasSize(1)); var embeddings = new ArrayList(); @@ -559,6 +813,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase { return new ChunkedResultsListener(); } + private static String getMatchedText(String text, ChunkedInference.TextOffset offset) { + return text.substring(offset.start(), offset.end()); + } + private static class ChunkedResultsListener implements ActionListener> { List results; diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java index e37e301dc422..8f7e1594086c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -31,9 +31,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.is; public class AmazonBedrockActionCreatorTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java index b9f0281c20c8..927535e5b123 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java @@ -34,13 +34,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java index fb538cfc5d55..a1ae30ff4185 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java @@ -42,13 +42,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 210fab457de1..d499b3a0ef31 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -40,6 +40,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; @@ -47,8 +49,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 593bac820933..9be096f789cc 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -40,13 +40,13 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java index 33d0e7d02a38..3437231d9699 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java @@ -42,13 +42,13 @@ import java.net.URISyntaxException; import java.util.List; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 3a512de25a39..b88f3b6d8202 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -39,13 +39,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index 6e98389728f1..bf6cc96748f2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -41,12 +41,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 5bd919e7bbc1..3777acbe0903 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -45,13 +45,13 @@ import java.util.List; import java.util.Map; import 
java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java index 28e182aa2d43..c490a3291424 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java @@ -20,11 +20,11 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.After; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java index 1cdce75d3ae0..a1cc1862bf04 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java @@ -38,12 +38,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static 
org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModelTests.createModel; import static org.hamcrest.Matchers.aMapWithSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index b3ec565b3146..e5fea7e08df9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -19,13 +19,13 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; @@ -212,7 +212,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -325,7 +325,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); assertThat(webServer.requests(), hasSize(2)); { @@ -383,7 +383,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { var 
result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java index 7ac400667a81..95ff8483afd7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java @@ -43,12 +43,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests.createModel; import static org.hamcrest.Matchers.aMapWithSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index 49f401a7e968..992313e82d8b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -34,6 +34,8 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; @@ -41,8 +43,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static 
org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index c5aae6423888..70d3226d75f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; @@ -52,7 +53,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.hamcrest.Matchers.containsString; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 5d732a4416a9..76c82b971a1c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -39,6 +39,7 @@ import java.io.IOException; import java.util.List; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static 
org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; @@ -46,7 +47,6 @@ import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiAct import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java index 50c64468d732..f5b82f0eddd0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -36,12 +36,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java index a19956bc6bda..4320764dfcc9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -44,14 +44,14 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationBinary; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static 
org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java index 6d601b4b08c5..a52da269a5d1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java @@ -33,9 +33,9 @@ import java.nio.charset.CharacterCodingException; import java.nio.charset.StandardCharsets; import java.util.List; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index a8f37aedcece..6e33d4b2615e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -31,11 +31,11 @@ import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.is; 
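// The import changes above are mechanical: the shared result test helpers moved from the inference
// plugin's test tree (org.elasticsearch.xpack.inference.results.*) into x-pack core
// (org.elasticsearch.xpack.core.inference.results.*), and the float helper was renamed from
// TextEmbeddingResultsTests to TextEmbeddingFloatResultsTests. A minimal sketch of one migrated
// call site, condensed from the hunks above (surrounding test scaffolding assumed, not shown):
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;

// Call sites change only through the statically imported owner; the expected map itself is untouched:
assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));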
import static org.mockito.Mockito.mock; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 0c296fc5729b..cbe0e92be80a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -44,13 +44,13 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index e28d4f9608ae..f648d58d0295 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -13,16 +13,16 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentEOFException; import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings; +import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 2c2e5f5d6d72..af6b398133e7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -179,13 +179,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (String input : inputs) { byte[] values = new byte[embeddingLength]; for (int j = 0; j < values.length; j++) { values[j] = randomByte(); } - chunks.add(new TextEmbeddingByteResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length()))); + chunks.add( + new EmbeddingResults.Chunk( + new TextEmbeddingByteResults.Embedding(values), + new ChunkedInference.TextOffset(0, input.length()) + ) + ); } return new ChunkedInferenceEmbedding(chunks); } @@ -195,13 +200,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (String input : inputs) { float[] values = new float[embeddingLength]; for (int j = 0; j < values.length; j++) { values[j] = randomFloat(); } - chunks.add(new TextEmbeddingFloatResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length()))); + chunks.add( + new EmbeddingResults.Chunk( + new TextEmbeddingFloatResults.Embedding(values), + new ChunkedInference.TextOffset(0, input.length()) + ) + ); } return new ChunkedInferenceEmbedding(chunks); } @@ -211,13 +221,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase inputs, boolean withFloats) { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (String input : inputs) { var tokens = new ArrayList(); for (var token : input.split("\\s+")) { tokens.add(new WeightedToken(token, withFloats ? 
randomFloat() : randomIntBetween(1, 255))); } - chunks.add(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.length()))); + chunks.add( + new EmbeddingResults.Chunk( + new SparseEmbeddingResults.Embedding(tokens, false), + new ChunkedInference.TextOffset(0, input.length()) + ) + ); } return new ChunkedInferenceEmbedding(chunks); } @@ -309,7 +324,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var entry : field.inference().chunks().entrySet()) { String entryField = entry.getKey(); List entryChunks = entry.getValue(); @@ -320,7 +335,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), offset); - case BYTE, BIT -> new TextEmbeddingByteResults.Chunk(byteArrayOf(values), offset); + EmbeddingResults.Embedding embedding = switch (elementType) { + case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values)); }; - chunks.add(chunk); + chunks.add(new EmbeddingResults.Chunk(embedding, offset)); } } return new ChunkedInferenceEmbedding(chunks); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index ddaea0dc3c9a..f53249d98d69 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -20,10 +20,10 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests; -import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests; import java.util.EnumSet; import java.util.HashMap; @@ -953,7 +953,7 @@ public class ServiceUtilsTests extends ESTestCase { var model = mock(Model.class); when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - var textEmbedding = TextEmbeddingResultsTests.createRandomResults(); + var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(7); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index c4c6b69b117b..a23b2556876e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -41,7 +42,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderT import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 43c6422ee041..ba3f35c3eebb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -65,13 +65,13 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; -import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure; import static 
org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettingsTests.getAmazonBedrockSecretSettingsMap; @@ -1458,10 +1458,10 @@ public class AmazonBedrockServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123F, 0.678F }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1470,10 +1470,10 @@ public class AmazonBedrockServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223F, 0.278F }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 369fdc6d4684..ac5e3e819a0d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1205,10 +1205,10 @@ public class AzureAiStudioServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1217,10 +1217,10 @@ public class AzureAiStudioServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 1.0123f, -1.0123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) 
floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 837c3ec8e6e2..7c010970eed5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -63,6 +63,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -72,7 +73,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap; @@ -1355,10 +1355,10 @@ public class AzureOpenAiServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1367,10 +1367,10 @@ public class AzureOpenAiServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 1.123f, -1.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f1f8fb0140a3..877cc294fec6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -67,6 +67,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -75,7 +76,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; @@ -1468,7 +1468,7 @@ public class CohereServiceTests extends ESTestCase { assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1479,7 +1479,7 @@ public class CohereServiceTests extends ESTestCase { assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1565,16 +1565,22 @@ public class CohereServiceTests extends ESTestCase { var byteResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(byteResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset()); - assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class)); - assertArrayEquals(new byte[] { 23, -23 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding()); + assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + assertArrayEquals( + new byte[] { 23, -23 }, + ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() + ); } { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var byteResult = 
(ChunkedInferenceEmbedding) results.get(1); assertThat(byteResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset()); - assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class)); - assertArrayEquals(new byte[] { 24, -24 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding()); + assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); + assertArrayEquals( + new byte[] { 24, -24 }, + ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values() + ); } MatcherAssert.assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index acbe471a3fb2..f38f83cf8864 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -37,7 +37,9 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -48,7 +50,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; @@ -645,8 +646,11 @@ public class ElasticInferenceServiceTests extends ESTestCase { sparseResult.chunks(), is( List.of( - new SparseEmbeddingResults.Chunk( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + new EmbeddingResults.Chunk( + new SparseEmbeddingResults.Embedding( + List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + false + ), new ChunkedInference.TextOffset(0, "input text".length()) ) ) @@ -767,8 +771,11 @@ public class ElasticInferenceServiceTests extends ESTestCase { sparseResult.chunks(), is( List.of( - new SparseEmbeddingResults.Chunk( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + new EmbeddingResults.Chunk( + new SparseEmbeddingResults.Embedding( + List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + false + ), new 
ChunkedInference.TextOffset(0, "input text".length()) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 2bf47b06c771..2ed414a16cf2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -896,20 +896,20 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); assertThat(result1.chunks(), hasSize(1)); - assertThat(result1.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), - ((TextEmbeddingFloatResults.Chunk) result1.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(), 0.0001f ); assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); assertThat(result2.chunks(), hasSize(1)); - assertThat(result2.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), - ((TextEmbeddingFloatResults.Chunk) result2.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(), 0.0001f ); assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); @@ -972,18 +972,18 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { assertThat(chunkedResponse, hasSize(2)); assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); - assertThat(result1.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class)); + assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - ((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens() + ((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens() ); assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); - assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class)); + assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), - ((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens() + 
((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens() ); assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); gotResults.set(true); @@ -1044,18 +1044,18 @@ public class ElasticsearchInternalServiceTests extends ESTestCase { assertThat(chunkedResponse, hasSize(2)); assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); - assertThat(result1.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class)); + assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - ((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens() + ((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens() ); assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); - assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class)); + assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), - ((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens() + ((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens() ); assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); gotResults.set(true); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 9343d1c25f48..6cf87f3043b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -64,6 +64,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -72,7 +73,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static 
org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -882,11 +882,11 @@ public class GoogleAiStudioServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -897,11 +897,11 @@ public class GoogleAiStudioServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0456f, -0.0456f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index d732f4f85f60..a09369aaa893 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -109,8 +110,8 @@ public class HuggingFaceElserServiceTests extends ESTestCase { sparseResult.chunks(), is( List.of( - new SparseEmbeddingResults.Chunk( - List.of(new WeightedToken(".", 0.13315596f)), + new EmbeddingResults.Chunk( + new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(".", 0.13315596f)), false), new ChunkedInference.TextOffset(0, "abc".length()) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 
32a597aecb41..f8673edd8172 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -35,12 +35,12 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; @@ -58,13 +58,13 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -788,10 +788,10 @@ public class HuggingFaceServiceTests extends ESTestCase { var embeddingResult = (ChunkedInferenceEmbedding) result; assertThat(embeddingResult.chunks(), hasSize(1)); assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length()))); - assertThat(embeddingResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { -0.0123f, 0.0123f }, - ((TextEmbeddingFloatResults.Chunk) embeddingResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(), 0.001f ); assertThat(webServer.requests(), hasSize(1)); @@ -842,10 +842,10 @@ public class 
HuggingFaceServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index d74c9a7eafd0..eeba7a6425d9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -68,6 +68,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -76,7 +77,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -734,11 +734,11 @@ public class IbmWatsonxServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0123f, -0.0123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -749,11 +749,11 @@ public class IbmWatsonxServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); 
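// Every service-test hunk above applies the same rewrite: a chunk is no longer a per-type record
// (TextEmbeddingFloatResults.Chunk, TextEmbeddingByteResults.Chunk, SparseEmbeddingResults.Chunk) but a
// single EmbeddingResults.Chunk pairing a typed Embedding with the offset into the input text.
// A condensed sketch of the pattern as these tests now use it (local variable names are illustrative only):
var chunk = new EmbeddingResults.Chunk(
    new TextEmbeddingFloatResults.Embedding(new float[] { 0.123f, -0.123f }),
    new ChunkedInference.TextOffset(0, 1)
);
// Assertions unwrap the typed embedding and read values() for dense results or tokens() for sparse results:
assertThat(chunk.embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
assertArrayEquals(new float[] { 0.123f, -0.123f }, ((TextEmbeddingFloatResults.Embedding) chunk.embedding()).values(), 0.0f);
assertEquals(new ChunkedInference.TextOffset(0, 1), chunk.offset());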
assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.0456f, -0.0456f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 392069be0190..fabcca09d3e3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -63,6 +63,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -71,7 +72,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; @@ -1833,10 +1833,10 @@ public class JinaAIServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1845,10 +1845,10 @@ public class JinaAIServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + 
assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 6acafd59272e..98ecd23ad8b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -684,11 +684,11 @@ public class MistralServiceTests extends ESTestCase { assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -696,11 +696,11 @@ public class MistralServiceTests extends ESTestCase { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index a522b83a4a67..3b876f58e244 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -72,6 +72,7 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; @@ -81,7 +82,6 @@ import static 
org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap; @@ -1871,11 +1871,11 @@ public class OpenAiServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } @@ -1884,11 +1884,11 @@ public class OpenAiServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); - assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding() + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java index b14a1f8f3cc7..949d20c7c7ce 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.junit.Before; import org.mockito.Mock; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index 
d596d53ba510..10ad38e7eee5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -14,12 +14,12 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.EmptyTaskSettingsTests; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; -import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests; import org.junit.Before; import org.mockito.Mock; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 3a5fce350046..067e2d6408db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -61,6 +61,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -69,7 +70,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; @@ -1840,10 +1840,10 @@ public class VoyageAIServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.getFirst(); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset()); - assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); 
assertArrayEquals( new float[] { 0.123f, -0.123f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().getFirst()).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), 0.0f ); } @@ -1852,10 +1852,10 @@ public class VoyageAIServiceTests extends ESTestCase { var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset()); - assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class)); + assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertArrayEquals( new float[] { 0.223f, -0.223f }, - ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().getFirst()).embedding(), + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().getFirst().embedding()).values(), 0.0f ); }
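
Editor's note: the test updates above all follow one mechanical pattern. A chunk returned by ChunkedInferenceEmbedding#chunks() is no longer a service-specific TextEmbeddingFloatResults.Chunk; it is the generic EmbeddingResults.Chunk record pairing an Embedding with a ChunkedInference.TextOffset, so assertions unwrap chunk.embedding() and read the raw floats via values(). The sketch below restates that pattern outside any single service test. It is not part of the PR: the class name ChunkAccessSketch, the method firstChunkValues, and the assumption that the caller holds a List<ChunkedInference> are placeholders for illustration only.

import java.util.List;

import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;

/**
 * Not part of the PR: a minimal sketch of the chunk-access pattern the updated tests rely on,
 * assuming the caller already has the List<ChunkedInference> produced by a chunked inference call.
 */
final class ChunkAccessSketch {

    /** Unwraps the first chunk of the first result down to its raw float values. */
    static float[] firstChunkValues(List<ChunkedInference> results) {
        // A dense-embedding result is a ChunkedInferenceEmbedding holding EmbeddingResults.Chunk records.
        var embeddingResult = (ChunkedInferenceEmbedding) results.get(0);

        // Each chunk pairs an embedding with the offset into the input text; before this PR the
        // chunk itself was a service-specific TextEmbeddingFloatResults.Chunk.
        var chunk = embeddingResult.chunks().get(0);
        ChunkedInference.TextOffset offset = chunk.offset(); // still available, unchanged by the refactor

        // The typed embedding now sits behind chunk.embedding(); values() exposes the float[]
        // that Chunk#embedding() used to return directly.
        var embedding = (TextEmbeddingFloatResults.Embedding) chunk.embedding();
        return embedding.values();
    }

    private ChunkAccessSketch() {}
}

In other words, the tests now assert on the embedding carried inside a generic chunk rather than on a per-service chunk subtype, which appears to be what lets ChunkedInferenceEmbedding serialize any embedding type uniformly via toBytesRef in chunksAsByteReference.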