Add max.chunks to EmbeddingRequestChunker to prevent OOM (#123150)

* add max number of chunks

* wire merge function

* implement sparse merge function

* move tests to correct package/file

* float merge function

* bytes merge function

* more accurate byte average

* spotless

* Fix/improve EmbeddingRequestChunkerTests

* Remove TODO

* remove unnecessary field

* remove Chunk generic

* add TODO

* Remove specialized chunks

* add comment

* Update docs/changelog/123150.yaml

* update changelog
Jan Kuipers authored on 2025-03-13 11:38:12 +01:00 (committed by GitHub)
parent c24f77f547
commit a503497bce
77 changed files with 756 additions and 355 deletions


@@ -0,0 +1,5 @@
+pr: 123150
+summary: Limit the number of chunks for semantic text to prevent high memory usage
+area: Machine Learning
+type: feature
+issues: []


@@ -17,7 +17,7 @@ import java.util.List;
 
 import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
 
-public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
+public record ChunkedInferenceEmbedding(List<EmbeddingResults.Chunk> chunks) implements ChunkedInference {
 
     public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
         validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
@@ -27,10 +27,7 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> c
             results.add(
                 new ChunkedInferenceEmbedding(
                     List.of(
-                        new SparseEmbeddingResults.Chunk(
-                            sparseEmbeddingResults.embeddings().get(i).tokens(),
-                            new TextOffset(0, inputs.get(i).length())
-                        )
+                        new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length()))
                     )
                 )
             );
@@ -41,10 +38,10 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> c
 
     @Override
     public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
-        var asChunk = new ArrayList<Chunk>();
-        for (var chunk : chunks()) {
-            asChunk.add(chunk.toChunk(xcontent));
+        List<Chunk> chunkedInferenceChunks = new ArrayList<>();
+        for (EmbeddingResults.Chunk embeddingResultsChunk : chunks()) {
+            chunkedInferenceChunks.add(new Chunk(embeddingResultsChunk.offset(), embeddingResultsChunk.embedding().toBytesRef(xcontent)));
         }
-        return asChunk.iterator();
+        return chunkedInferenceChunks.iterator();
     }
 }


@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.core.inference.results;
 
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.XContent;
@@ -19,31 +20,30 @@ import java.util.List;
  * A call to the inference service may contain multiple input texts, so this results may
  * contain multiple results.
  */
-public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
-    extends
-        InferenceServiceResults {
-
-    /**
-     * A resulting embedding together with the offset into the input text.
-     */
-    interface Chunk {
-        ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
-
-        ChunkedInference.TextOffset offset();
-    }
+public interface EmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends InferenceServiceResults {
 
     /**
      * A resulting embedding for one of the input texts to the inference service.
     */
-    interface Embedding<C extends Chunk> {
+    interface Embedding<E extends Embedding<E>> {
+
         /**
-         * Combines the resulting embedding with the offset into the input text into a chunk.
+         * Merges the existing embedding and provided embedding into a new embedding.
         */
-        C toChunk(ChunkedInference.TextOffset offset);
+        E merge(E embedding);
+
+        /**
+         * Serializes the embedding to bytes.
+         */
+        BytesReference toBytesRef(XContent xContent) throws IOException;
     }
 
     /**
     * The resulting list of embeddings for the input texts to the inference service.
     */
    List<E> embeddings();
+
+    /**
+     * A resulting embedding together with the offset into the input text.
+     */
+    record Chunk(Embedding<?> embedding, ChunkedInference.TextOffset offset) {}
 }

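For orientation: the reworked interface drops the per-type Chunk hierarchy. An embedding now knows how to merge with another embedding of the same type and to serialize itself, and a chunk is just an embedding plus a text offset. A minimal sketch of an implementation, using a hypothetical ToyEmbedding (the unweighted average is a simplification; the real implementations below weight each side by how many embeddings it already represents):

record ToyEmbedding(float[] values) implements EmbeddingResults.Embedding<ToyEmbedding> {

    @Override
    public ToyEmbedding merge(ToyEmbedding embedding) {
        // Simplified element-wise average of the two embeddings.
        float[] merged = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            merged[i] = (values[i] + embedding.values[i]) / 2;
        }
        return new ToyEmbedding(merged);
    }

    @Override
    public BytesReference toBytesRef(XContent xContent) throws IOException {
        XContentBuilder b = XContentBuilder.builder(xContent);
        b.startArray();
        for (float v : values) {
            b.value(v);
        }
        b.endArray();
        return BytesReference.bytes(b);
    }
}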

@@ -14,7 +14,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
@@ -27,17 +26,17 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
-public record SparseEmbeddingResults(List<Embedding> embeddings)
-    implements
-        EmbeddingResults<SparseEmbeddingResults.Chunk, SparseEmbeddingResults.Embedding> {
+public record SparseEmbeddingResults(List<Embedding> embeddings) implements EmbeddingResults<SparseEmbeddingResults.Embedding> {
 
     public static final String NAME = "sparse_embedding_results";
     public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();
@@ -124,7 +123,7 @@ public record SparseEmbeddingResults(List<Embedding> embeddings)
         implements
             Writeable,
             ToXContentObject,
-            EmbeddingResults.Embedding<Chunk> {
+            EmbeddingResults.Embedding<Embedding> {
 
         public static final String EMBEDDING = "embedding";
         public static final String IS_TRUNCATED = "is_truncated";
@@ -175,18 +174,35 @@ public record SparseEmbeddingResults(List<Embedding> embeddings)
         }
 
         @Override
-        public Chunk toChunk(ChunkedInference.TextOffset offset) {
-            return new Chunk(tokens, offset);
-        }
-    }
-
-    public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
+        public Embedding merge(Embedding embedding) {
+            // This code assumes that the tokens are sorted by weight in descending order.
+            // If that's not the case, the resulting merged embedding will be incorrect.
+            List<WeightedToken> mergedTokens = new ArrayList<>();
+            Set<String> seenTokens = new HashSet<>();
+            int i = 0;
+            int j = 0;
+            // TODO: maybe truncate tokens here when it's getting too large?
+            while (i < tokens().size() || j < embedding.tokens().size()) {
+                WeightedToken token;
+                if (i == tokens().size()) {
+                    token = embedding.tokens().get(j++);
+                } else if (j == embedding.tokens().size()) {
+                    token = tokens().get(i++);
+                } else if (tokens.get(i).weight() > embedding.tokens().get(j).weight()) {
+                    token = tokens().get(i++);
+                } else {
+                    token = embedding.tokens().get(j++);
+                }
+                if (seenTokens.add(token.token())) {
+                    mergedTokens.add(token);
+                }
+            }
+            boolean mergedIsTruncated = isTruncated || embedding.isTruncated();
+            return new Embedding(mergedTokens, mergedIsTruncated);
         }
 
-        private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
+        @Override
+        public BytesReference toBytesRef(XContent xContent) throws IOException {
            XContentBuilder b = XContentBuilder.builder(xContent);
            b.startObject();
            for (var weightedToken : tokens) {

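The sparse merge above is a two-pointer walk over two token lists that are assumed sorted by weight in descending order; it keeps the highest weight seen for each token and ORs the truncation flags. A small sketch of how it behaves, using the classes from this diff (mirrors testEmbeddingMerge further down):

var left = new SparseEmbeddingResults.Embedding(
    List.of(new WeightedToken("this", 1.0f), new WeightedToken("is", 0.8f), new WeightedToken("the", 0.6f)),
    false
);
var right = new SparseEmbeddingResults.Embedding(
    List.of(new WeightedToken("this", 0.95f), new WeightedToken("is", 0.85f)),
    false
);
// "this" keeps 1.0 (left wins), "is" keeps 0.85 (right wins), "the" survives
// unopposed: [this: 1.0, is: 0.85, the: 0.6]; isTruncated = false || false.
var merged = left.merge(right);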

@@ -40,9 +40,11 @@ import java.util.Objects;
  *     ]
  * }
 */
+// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
+// Embedding.merge method for bits. TODO: implement a proper merge method
 public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
     implements
-        TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
+        TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
 
     public static final String NAME = "text_embedding_service_bit_results";
     public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";


@@ -15,7 +15,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -48,9 +47,7 @@ import java.util.Objects;
  *     ]
  * }
 */
-public record TextEmbeddingByteResults(List<Embedding> embeddings)
-    implements
-        TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
+public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
 
     public static final String NAME = "text_embedding_service_byte_results";
     public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
@@ -118,9 +115,20 @@ public record TextEmbeddingByteResults(List<Embedding> embeddings)
         return Objects.hash(embeddings);
     }
 
-    public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
+    // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+    // embeddings should happen inbetween serializations.
+    public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
+        implements
+            Writeable,
+            ToXContentObject,
+            EmbeddingResults.Embedding<Embedding> {
 
         public static final String EMBEDDING = "embedding";
 
+        public Embedding(byte[] values) {
+            this(values, null, 1);
+        }
+
         public Embedding(StreamInput in) throws IOException {
             this(in.readByteArray());
         }
@@ -187,25 +195,26 @@ public record TextEmbeddingByteResults(List<Embedding> embeddings)
         }
 
         @Override
-        public Chunk toChunk(ChunkedInference.TextOffset offset) {
-            return new Chunk(values, offset);
-        }
-    }
-
-    /**
-     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
-     */
-    public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
+        public Embedding merge(Embedding embedding) {
+            byte[] newValues = new byte[values.length];
+            int[] newSumMergedValues = new int[values.length];
+            int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
+            for (int i = 0; i < values.length; i++) {
+                newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i])
+                    + (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
+                // Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
+                // closest byte instead of truncating.
+                newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
+            }
+            return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
         }
 
-        private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
+        @Override
+        public BytesReference toBytesRef(XContent xContent) throws IOException {
             XContentBuilder builder = XContentBuilder.builder(xContent);
             builder.startArray();
-            for (byte v : value) {
-                builder.value(v);
+            for (byte value : values) {
+                builder.value(value);
             }
             builder.endArray();
             return BytesReference.bytes(builder);

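Two details of the byte merge are easy to miss: the exact integer sums are carried along in sumMergedValues so that repeated merges do not compound rounding error, and adding newNumberOfMergedEmbeddings / 2 to the numerator rounds to the nearest byte instead of truncating toward zero. Worked through on the values from testEmbeddingMerge below:

// Merging {1, 1, -128} with {1, 0, 127}: integer sums are {2, 1, -1} over
// n = 2 embeddings; (sum + n / 2) / n yields {1, 1, 0}, where plain
// truncating division would have produced {1, 0, 0}.
var merged = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 })
    .merge(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 }));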

@@ -16,7 +16,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
@@ -53,9 +52,7 @@ import java.util.stream.Collectors;
  *     ]
  * }
 */
-public record TextEmbeddingFloatResults(List<Embedding> embeddings)
-    implements
-        TextEmbeddingResults<TextEmbeddingFloatResults.Chunk, TextEmbeddingFloatResults.Embedding> {
+public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
 
     public static final String NAME = "text_embedding_service_results";
     public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
@@ -155,9 +152,19 @@ public record TextEmbeddingFloatResults(List<Embedding> embeddings)
         return Objects.hash(embeddings);
     }
 
-    public record Embedding(float[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
+    // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+    // embeddings should happen inbetween serializations.
+    public record Embedding(float[] values, int numberOfMergedEmbeddings)
+        implements
+            Writeable,
+            ToXContentObject,
+            EmbeddingResults.Embedding<Embedding> {
 
         public static final String EMBEDDING = "embedding";
 
+        public Embedding(float[] values) {
+            this(values, 1);
+        }
+
         public Embedding(StreamInput in) throws IOException {
             this(in.readFloatArray());
         }
@@ -221,25 +228,21 @@ public record TextEmbeddingFloatResults(List<Embedding> embeddings)
         }
 
         @Override
-        public Chunk toChunk(ChunkedInference.TextOffset offset) {
-            return new Chunk(values, offset);
-        }
-    }
-
-    public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
-        }
-
-        /**
-         * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
-         */
-        private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
+        public Embedding merge(Embedding embedding) {
+            float[] mergedValues = new float[values.length];
+            for (int i = 0; i < values.length; i++) {
+                mergedValues[i] = (numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i])
+                    / (numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
+            }
+            return new Embedding(mergedValues, numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
+        }
+
+        @Override
+        public BytesReference toBytesRef(XContent xContent) throws IOException {
             XContentBuilder b = XContentBuilder.builder(xContent);
             b.startArray();
-            for (float v : value) {
-                b.value(v);
+            for (float value : values) {
+                b.value(value);
             }
             b.endArray();
             return BytesReference.bytes(b);

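The float merge is a weighted running mean: each operand contributes in proportion to the number of original embeddings it already represents, so merging chunk embeddings one at a time gives the same result as averaging them all at once. On the first component of the testEmbeddingMerge data below:

// (1 * 0.1f + 1 * 0.0f) / 2 = 0.05f after the first merge, and then
// (2 * 0.05f + 1 * 0.2f) / 3 = 0.1f after the second.
var merged = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f })
    .merge(new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f }))
    .merge(new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f }));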

@@ -7,9 +7,7 @@
 
 package org.elasticsearch.xpack.core.inference.results;
 
-public interface TextEmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
-    extends
-        EmbeddingResults<C, E> {
+public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
 
     /**
      * Returns the first text embedding entry in the result list's array size.


@@ -5,12 +5,11 @@
  * 2.0.
 */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 
 import java.io.IOException;
 import java.util.ArrayList;


@@ -5,7 +5,7 @@
  * 2.0.
 */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -14,7 +14,6 @@ import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
 
 import java.io.IOException;
 import java.util.ArrayList;


@@ -5,12 +5,11 @@
  * 2.0.
 */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 
@@ -20,6 +19,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
 public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase<SparseEmbeddingResults> {
@@ -161,6 +161,44 @@ public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase
         );
     }
 
+    public void testEmbeddingMerge() {
+        SparseEmbeddingResults.Embedding embedding1 = new SparseEmbeddingResults.Embedding(
+            List.of(
+                new WeightedToken("this", 1.0f),
+                new WeightedToken("is", 0.8f),
+                new WeightedToken("the", 0.6f),
+                new WeightedToken("first", 0.4f),
+                new WeightedToken("embedding", 0.2f)
+            ),
+            true
+        );
+        SparseEmbeddingResults.Embedding embedding2 = new SparseEmbeddingResults.Embedding(
+            List.of(
+                new WeightedToken("this", 0.95f),
+                new WeightedToken("is", 0.85f),
+                new WeightedToken("another", 0.65f),
+                new WeightedToken("embedding", 0.15f)
+            ),
+            false
+        );
+        assertThat(
+            embedding1.merge(embedding2),
+            equalTo(
+                new SparseEmbeddingResults.Embedding(
+                    List.of(
+                        new WeightedToken("this", 1.0f),
+                        new WeightedToken("is", 0.85f),
+                        new WeightedToken("another", 0.65f),
+                        new WeightedToken("the", 0.6f),
+                        new WeightedToken("first", 0.4f),
+                        new WeightedToken("embedding", 0.2f)
+                    ),
+                    true
+                )
+            )
+        );
+    }
+
     public record EmbeddingExpectation(Map<String, Float> tokens, boolean isTruncated) {}
 
     public static Map<String, Object> buildExpectationSparseEmbeddings(List<EmbeddingExpectation> embeddings) {


@@ -5,13 +5,11 @@
  * 2.0.
 */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;


@@ -5,12 +5,11 @@
  * 2.0.
 */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;
@@ -18,6 +17,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
 public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingByteResults> {
@@ -115,6 +115,16 @@ public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCa
         assertThat(firstEmbeddingSize, is(2));
     }
 
+    public void testEmbeddingMerge() {
+        TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
+        TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
+        TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
+        TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
+        mergedEmbedding = mergedEmbedding.merge(embedding3);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
+    }
+
     @Override
     protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
         return TextEmbeddingByteResults::new;


@@ -5,13 +5,11 @@
  * 2.0.
 */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;
@@ -19,9 +17,10 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
-public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingFloatResults> {
+public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingFloatResults> {
 
     public static TextEmbeddingFloatResults createRandomResults() {
         int embeddings = randomIntBetween(1, 10);
         List<TextEmbeddingFloatResults.Embedding> embeddingResults = new ArrayList<>(embeddings);
@@ -116,6 +115,16 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<T
         assertThat(firstEmbeddingSize, is(2));
     }
 
+    public void testEmbeddingMerge() {
+        TextEmbeddingFloatResults.Embedding embedding1 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
+        TextEmbeddingFloatResults.Embedding embedding2 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f });
+        TextEmbeddingFloatResults.Embedding embedding3 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f });
+        TextEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f })));
+        mergedEmbedding = mergedEmbedding.merge(embedding3);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f })));
+    }
+
     @Override
     protected Writeable.Reader<TextEmbeddingFloatResults> instanceReader() {
         return TextEmbeddingFloatResults::new;


@@ -35,6 +35,7 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 
 import java.io.IOException;
@@ -181,8 +182,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
                 results.add(
                     new ChunkedInferenceEmbedding(
                         List.of(
-                            new TextEmbeddingFloatResults.Chunk(
-                                nonChunkedResults.embeddings().get(i).values(),
+                            new EmbeddingResults.Chunk(
+                                nonChunkedResults.embeddings().get(i),
                                 new ChunkedInference.TextOffset(0, input.get(i).length())
                             )
                         )


@@ -33,6 +33,7 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 
@@ -172,7 +173,12 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
             }
             results.add(
                 new ChunkedInferenceEmbedding(
-                    List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
+                    List.of(
+                        new EmbeddingResults.Chunk(
+                            new SparseEmbeddingResults.Embedding(tokens, false),
+                            new ChunkedInference.TextOffset(0, input.get(i).length())
+                        )
+                    )
                 )
             );
         }


@@ -36,7 +36,7 @@ import java.util.stream.Collectors;
  * processing and map the results back to the original element
  * in the input list.
 */
-public class EmbeddingRequestChunker {
+public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
 
     // Visible for testing
     record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<String> inputs) {
@@ -56,13 +56,19 @@ public class EmbeddingRequestChunker {
     private static final int DEFAULT_WORDS_PER_CHUNK = 250;
     private static final int DEFAULT_CHUNK_OVERLAP = 100;
 
-    private final List<String> inputs;
-    private final List<List<Request>> requests;
+    // The maximum number of chunks that is stored for any input text.
+    // If the configured chunker chunks the text into more chunks, each
+    // chunk is sent to the inference service separately, but the results
+    // are merged so that only this maximum number of chunks is stored.
+    private static final int MAX_CHUNKS = 512;
+
     private final List<BatchRequest> batchRequests;
     private final AtomicInteger resultCount = new AtomicInteger();
 
-    private final List<AtomicReferenceArray<EmbeddingResults.Embedding<?>>> results;
-    private final AtomicArray<Exception> errors;
+    private final List<List<Integer>> resultOffsetStarts;
+    private final List<List<Integer>> resultOffsetEnds;
+    private final List<AtomicReferenceArray<E>> resultEmbeddings;
+    private final AtomicArray<Exception> resultsErrors;
     private ActionListener<List<ChunkedInference>> finalListener;
 
     public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch) {
@@ -74,31 +80,41 @@ public class EmbeddingRequestChunker {
     }
 
     public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, ChunkingSettings chunkingSettings) {
-        this.inputs = inputs;
-        this.results = new ArrayList<>(inputs.size());
-        this.errors = new AtomicArray<>(inputs.size());
+        this.resultEmbeddings = new ArrayList<>(inputs.size());
+        this.resultOffsetStarts = new ArrayList<>(inputs.size());
+        this.resultOffsetEnds = new ArrayList<>(inputs.size());
+        this.resultsErrors = new AtomicArray<>(inputs.size());
 
         if (chunkingSettings == null) {
             chunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP);
         }
         Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
 
-        this.requests = new ArrayList<>(inputs.size());
-
+        List<Request> allRequests = new ArrayList<>();
         for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) {
             List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex), chunkingSettings);
-            List<Request> requestForInput = new ArrayList<>(chunks.size());
+            int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
+            resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
+            resultOffsetStarts.add(new ArrayList<>(resultCount));
+            resultOffsetEnds.add(new ArrayList<>(resultCount));
             for (int chunkIndex = 0; chunkIndex < chunks.size(); chunkIndex++) {
-                requestForInput.add(new Request(inputIndex, chunkIndex, chunks.get(chunkIndex), inputs));
+                // If the number of chunks is larger than the maximum allowed value,
+                // scale the indices to [0, MAX) with similar number of original
+                // chunks in the final chunks.
+                int targetChunkIndex = chunks.size() <= MAX_CHUNKS ? chunkIndex : chunkIndex * MAX_CHUNKS / chunks.size();
+                if (resultOffsetStarts.getLast().size() <= targetChunkIndex) {
+                    resultOffsetStarts.getLast().add(chunks.get(chunkIndex).start());
+                    resultOffsetEnds.getLast().add(chunks.get(chunkIndex).end());
+                } else {
+                    resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
+                }
+                allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
             }
-            requests.add(requestForInput);
-            // size the results array with the expected number of request/responses
-            results.add(new AtomicReferenceArray<>(chunks.size()));
         }
 
         AtomicInteger counter = new AtomicInteger();
-        this.batchRequests = requests.stream()
-            .flatMap(List::stream)
+        this.batchRequests = allRequests.stream()
            .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
            .values()
            .stream()
@@ -134,20 +150,26 @@ public class EmbeddingRequestChunker {
         @Override
         public void onResponse(InferenceServiceResults inferenceServiceResults) {
-            if (inferenceServiceResults instanceof EmbeddingResults<?, ?> embeddingResults) {
-                if (embeddingResults.embeddings().size() != request.requests.size()) {
-                    onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size()));
-                    return;
-                }
-                for (int i = 0; i < embeddingResults.embeddings().size(); i++) {
-                    results.get(request.requests().get(i).inputIndex())
-                        .set(request.requests().get(i).chunkIndex(), embeddingResults.embeddings().get(i));
-                }
-                if (resultCount.incrementAndGet() == batchRequests.size()) {
-                    sendFinalResponse();
-                }
-            } else {
+            if (inferenceServiceResults instanceof EmbeddingResults<?> == false) {
                 onFailure(unexpectedResultTypeException(inferenceServiceResults.getWriteableName()));
+                return;
+            }
+            @SuppressWarnings("unchecked")
+            EmbeddingResults<E> embeddingResults = (EmbeddingResults<E>) inferenceServiceResults;
+            if (embeddingResults.embeddings().size() != request.requests.size()) {
+                onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size()));
+                return;
+            }
+            for (int i = 0; i < embeddingResults.embeddings().size(); i++) {
+                E newEmbedding = embeddingResults.embeddings().get(i);
+                resultEmbeddings.get(request.requests().get(i).inputIndex())
+                    .updateAndGet(
+                        request.requests().get(i).chunkIndex(),
+                        oldEmbedding -> oldEmbedding == null ? newEmbedding : oldEmbedding.merge(newEmbedding)
+                    );
+            }
+            if (resultCount.incrementAndGet() == batchRequests.size()) {
+                sendFinalResponse();
             }
         }
@@ -171,7 +193,7 @@ public class EmbeddingRequestChunker {
         @Override
         public void onFailure(Exception e) {
             for (Request request : request.requests) {
-                errors.set(request.inputIndex(), e);
+                resultsErrors.set(request.inputIndex(), e);
             }
             if (resultCount.incrementAndGet() == batchRequests.size()) {
                 sendFinalResponse();
@@ -180,10 +202,10 @@ public class EmbeddingRequestChunker {
     }
 
     private void sendFinalResponse() {
-        var response = new ArrayList<ChunkedInference>(inputs.size());
-        for (int i = 0; i < inputs.size(); i++) {
-            if (errors.get(i) != null) {
-                response.add(new ChunkedInferenceError(errors.get(i)));
+        var response = new ArrayList<ChunkedInference>(resultEmbeddings.size());
+        for (int i = 0; i < resultEmbeddings.size(); i++) {
+            if (resultsErrors.get(i) != null) {
+                response.add(new ChunkedInferenceError(resultsErrors.get(i)));
             } else {
                 response.add(mergeResultsWithInputs(i));
             }
@@ -191,14 +213,15 @@ public class EmbeddingRequestChunker {
         finalListener.onResponse(response);
     }
 
-    private ChunkedInference mergeResultsWithInputs(int index) {
+    private ChunkedInference mergeResultsWithInputs(int inputIndex) {
+        List<Integer> startOffsets = resultOffsetStarts.get(inputIndex);
+        List<Integer> endOffsets = resultOffsetEnds.get(inputIndex);
+        AtomicReferenceArray<E> embeddings = resultEmbeddings.get(inputIndex);
         List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
-        List<Request> request = requests.get(index);
-        AtomicReferenceArray<EmbeddingResults.Embedding<?>> result = results.get(index);
-        for (int i = 0; i < request.size(); i++) {
-            EmbeddingResults.Chunk chunk = result.get(i)
-                .toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end()));
-            chunks.add(chunk);
+        for (int i = 0; i < embeddings.length(); i++) {
+            ChunkedInference.TextOffset offset = new ChunkedInference.TextOffset(startOffsets.get(i), endOffsets.get(i));
+            chunks.add(new EmbeddingResults.Chunk(embeddings.get(i), offset));
         }
         return new ChunkedInferenceEmbedding(chunks);
     }

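The core of the OOM fix is the index scaling in the constructor above: when a text produces more than MAX_CHUNKS chunks, every chunk is still sent to the inference service, but chunk indices are compressed into [0, MAX_CHUNKS) so at most 512 embeddings per input are ever held in memory. Embeddings that land in the same slot are combined with Embedding.merge, and the slot's text offset is widened to span all of its original chunks. A standalone sketch of the mapping (the helper name is illustrative, not from the PR):

// With totalChunks = 1280 and maxChunks = 512, chunkIndex * maxChunks / totalChunks
// maps chunks 0-2 to slot 0, chunks 3-4 to slot 1, ..., chunk 1279 to slot 511.
static int targetChunkIndex(int chunkIndex, int totalChunks, int maxChunks) {
    return totalChunks <= maxChunks ? chunkIndex : chunkIndex * maxChunks / totalChunks;
}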

@@ -741,7 +741,7 @@ public final class ServiceUtils {
             InputType.INGEST,
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener.delegateFailureAndWrap((delegate, r) -> {
-                if (r instanceof TextEmbeddingResults<?, ?> embeddingResults) {
+                if (r instanceof TextEmbeddingResults<?> embeddingResults) {
                     try {
                         delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
                     } catch (Exception e) {


@@ -305,7 +305,7 @@ public class AlibabaCloudSearchService extends SenderService {
         AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
         var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             alibabaCloudSearchModel.getConfigurations().getChunkingSettings()


@@ -129,7 +129,7 @@ public class AmazonBedrockService extends SenderService {
         if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
             var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
 
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 inputs.getInputs(),
                 maxBatchSize,
                 baseAmazonBedrockModel.getConfigurations().getChunkingSettings()


@@ -121,7 +121,7 @@ public class AzureAiStudioService extends SenderService {
         if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
             var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
 
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 inputs.getInputs(),
                 EMBEDDING_MAX_BATCH_SIZE,
                 baseAzureAiStudioModel.getConfigurations().getChunkingSettings()


@@ -281,7 +281,7 @@ public class AzureOpenAiService extends SenderService {
         AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
         var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             azureOpenAiModel.getConfigurations().getChunkingSettings()


@@ -284,7 +284,7 @@ public class CohereService extends SenderService {
         CohereModel cohereModel = (CohereModel) model;
         var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             cohereModel.getConfigurations().getChunkingSettings()


@@ -725,7 +725,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         if (model instanceof ElasticsearchInternalModel esModel) {
 
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 input,
                 EMBEDDING_MAX_BATCH_SIZE,
                 esModel.getConfigurations().getChunkingSettings()


@@ -325,7 +325,7 @@ public class GoogleAiStudioService extends SenderService {
     ) {
         GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
            inputs.getInputs(),
            EMBEDDING_MAX_BATCH_SIZE,
            googleAiStudioModel.getConfigurations().getChunkingSettings()


@@ -228,7 +228,7 @@ public class GoogleVertexAiService extends SenderService {
         GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
         var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
            inputs.getInputs(),
            EMBEDDING_MAX_BATCH_SIZE,
            googleVertexAiModel.getConfigurations().getChunkingSettings()


@@ -127,7 +127,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         var huggingFaceModel = (HuggingFaceModel) model;
         var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
            inputs.getInputs(),
            EMBEDDING_MAX_BATCH_SIZE,
            huggingFaceModel.getConfigurations().getChunkingSettings()


@@ -27,6 +27,7 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
@@ -119,8 +120,8 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
             results.add(
                 new ChunkedInferenceEmbedding(
                     List.of(
-                        new TextEmbeddingFloatResults.Chunk(
-                            textEmbeddingResults.embeddings().get(i).values(),
+                        new EmbeddingResults.Chunk(
+                            textEmbeddingResults.embeddings().get(i),
                             new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
                         )
                     )


@@ -307,7 +307,7 @@ public class IbmWatsonxService extends SenderService {
     ) {
         IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model;
 
-        var batchedRequests = new EmbeddingRequestChunker(
+        var batchedRequests = new EmbeddingRequestChunker<>(
             input.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             model.getConfigurations().getChunkingSettings()


@@ -266,7 +266,7 @@ public class JinaAIService extends SenderService {
         JinaAIModel jinaaiModel = (JinaAIModel) model;
         var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
            inputs.getInputs(),
            EMBEDDING_MAX_BATCH_SIZE,
            jinaaiModel.getConfigurations().getChunkingSettings()


@@ -110,7 +110,7 @@ public class MistralService extends SenderService {
         var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
 
         if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 inputs.getInputs(),
                 MistralConstants.MAX_BATCH_SIZE,
                 mistralEmbeddingsModel.getConfigurations().getChunkingSettings()


@@ -335,7 +335,7 @@ public class OpenAiService extends SenderService {
         OpenAiModel openAiModel = (OpenAiModel) model;
         var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
            inputs.getInputs(),
            EMBEDDING_MAX_BATCH_SIZE,
            openAiModel.getConfigurations().getChunkingSettings()


@@ -33,7 +33,7 @@ public class TextEmbeddingModelValidator implements ModelValidator {
     }
 
     private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
-        if (results instanceof TextEmbeddingResults<?, ?> embeddingResults) {
+        if (results instanceof TextEmbeddingResults<?> embeddingResults) {
             var serviceSettings = model.getServiceSettings();
             var dimensions = serviceSettings.dimensions();
             int embeddingSize = getEmbeddingSize(embeddingResults);
@@ -67,7 +67,7 @@ public class TextEmbeddingModelValidator implements ModelValidator {
         }
     }
 
-    private int getEmbeddingSize(TextEmbeddingResults<?, ?> embeddingResults) {
+    private int getEmbeddingSize(TextEmbeddingResults<?> embeddingResults) {
         int embeddingSize;
         try {
             embeddingSize = embeddingResults.getFirstEmbeddingSize();


@@ -288,7 +288,7 @@ public class VoyageAIService extends SenderService {
         VoyageAIModel voyageaiModel = (VoyageAIModel) model;
         var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
            inputs.getInputs(),
            getBatchSize(voyageaiModel),
            voyageaiModel.getConfigurations().getChunkingSettings()


@@ -11,12 +11,12 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.LegacyMlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
-import org.elasticsearch.xpack.inference.results.LegacyMlTextEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -43,7 +43,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTe
     @Override
     protected InferenceAction.Response createTestInstance() {
         var result = switch (randomIntBetween(0, 2)) {
-            case 0 -> TextEmbeddingResultsTests.createRandomResults();
+            case 0 -> TextEmbeddingFloatResultsTests.createRandomResults();
             case 1 -> LegacyMlTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults();
             default -> SparseEmbeddingResultsTests.createRandomResults();
         };
@@ -87,7 +87,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTe
     }
 
     public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException {
-        var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
+        var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults();
         var instance = new InferenceAction.Response(embeddingResults);
         var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
         assertOnBWCObject(copy, instance, V_8_12_0);
@@ -103,7 +103,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTe
     // Technically we should never see a text embedding result in the transport version of this test because support
     // for it wasn't added until openai
     public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException {
-        var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
+        var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults();
         var instance = new InferenceAction.Response(embeddingResults);
         var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
         assertOnBWCObject(copy, instance, V_8_12_0);

View file

@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
@ -33,19 +34,19 @@ import static org.hamcrest.Matchers.startsWith;
public class EmbeddingRequestChunkerTests extends ESTestCase { public class EmbeddingRequestChunkerTests extends ESTestCase {
public void testEmptyInput_WordChunker() { public void testEmptyInput_WordChunker() {
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, empty()); assertThat(batches, empty());
} }
public void testEmptyInput_SentenceChunker() { public void testEmptyInput_SentenceChunker() {
var batches = new EmbeddingRequestChunker(List.of(), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners( var batches = new EmbeddingRequestChunker<>(List.of(), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners(
testListener() testListener()
); );
assertThat(batches, empty()); assertThat(batches, empty());
} }
public void testWhitespaceInput_SentenceChunker() { public void testWhitespaceInput_SentenceChunker() {
var batches = new EmbeddingRequestChunker(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1)) var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener()); .batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1));
@ -53,30 +54,29 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
} }
public void testBlankInput_WordChunker() { public void testBlankInput_WordChunker() {
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
} }
public void testBlankInput_SentenceChunker() { public void testBlankInput_SentenceChunker() {
var batches = new EmbeddingRequestChunker(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners( var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1))
testListener() .batchRequestsWithListeners(testListener());
);
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
} }
public void testInputThatDoesNotChunk_WordChunker() { public void testInputThatDoesNotChunk_WordChunker() {
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
} }
public void testInputThatDoesNotChunk_SentenceChunker() { public void testInputThatDoesNotChunk_SentenceChunker() {
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1)) var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener()); .batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1));
@ -85,14 +85,14 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
public void testShortInputsAreSingleBatch() { public void testShortInputsAreSingleBatch() {
String input = "one chunk"; String input = "one chunk";
var batches = new EmbeddingRequestChunker(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), contains(input)); assertThat(batches.get(0).batch().inputs(), contains(input));
} }
public void testMultipleShortInputsAreSingleBatch() { public void testMultipleShortInputsAreSingleBatch() {
List<String> inputs = List.of("1st small", "2nd small", "3rd small"); List<String> inputs = List.of("1st small", "2nd small", "3rd small");
var batches = new EmbeddingRequestChunker(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch(); EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch();
assertEquals(batch.inputs(), inputs); assertEquals(batch.inputs(), inputs);
@ -113,7 +113,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
inputs.add("input " + i); inputs.add("input " + i);
} }
var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(4)); assertThat(batches, hasSize(4));
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));
@ -148,7 +148,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
inputs.add("input " + i); inputs.add("input " + i);
} }
var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings())
.batchRequestsWithListeners(testListener()); .batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(4)); assertThat(batches, hasSize(4));
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
@ -190,7 +190,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(2)); assertThat(batches, hasSize(2));
@ -234,6 +234,260 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(request.chunkText(), equalTo("3rd small")); assertThat(request.chunkText(), equalTo("3rd small"));
} }
public void testVeryLongInput_Sparse() {
int batchSize = 5;
int chunkSize = 20;
int numberOfWordsInPassage = (chunkSize * 10000);
var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");
var finalListener = testListener();
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
.batchRequestsWithListeners(finalListener);
// The very long passage is split into 10000 chunks for inference, so
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
// Produce inference results for each request, with just the token
// "word" and increasing weights.
float weight = 0f;
for (var batch : batches) {
var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1 / 16384f;
embeddings.add(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("word", weight)), false));
}
batch.listener().onResponse(new SparseEmbeddingResults(embeddings));
}
assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(3));
// The first input has the token with weight 1/16384f.
ChunkedInference inference = finalListener.results.get(0);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f)));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
// inference. They get the embeddings with token "word" and weights 2/16384 ... 10001/16384.
// Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
// and therefore 380 or 400 words. For each, the max weight per token is collected.
inference = finalListener.results.get(1);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(512));
// The first merged chunk consists of 20 small chunks (so 400 words) and the max
// weight is the weight of the 20th small chunk (so 21/16384).
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f)));
// The last merged chunk consists of 19 small chunks (so 380 words) and the max
// weight is the weight of the 10000th small chunk (so 10001/16384).
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f)));
// The last input has the token with weight 10002/16384.
inference = finalListener.results.get(2);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f)));
}
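For reference, the max-merge behavior these sparse assertions rely on can be illustrated with a minimal sketch (a hypothetical helper, not the production merge function): with the cap of 512 merged chunks, the 10000 small chunks are grouped into 272 groups of 20 and 240 groups of 19 (272 * 20 + 240 * 19 = 10000), and each group is merged by keeping the maximum weight per token. The local Token record stands in for the real weighted-token type.
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
class SparseMergeSketch {
    record Token(String token, float weight) {}
    // Merge consecutive sparse chunks by taking the max weight per token,
    // matching the 21/16384 and 10001/16384 assertions in the test above.
    static List<Token> merge(List<List<Token>> chunks) {
        Map<String, Float> maxWeights = new LinkedHashMap<>();
        for (List<Token> chunk : chunks) {
            for (Token t : chunk) {
                maxWeights.merge(t.token(), t.weight(), Float::max);
            }
        }
        List<Token> merged = new ArrayList<>();
        maxWeights.forEach((token, weight) -> merged.add(new Token(token, weight)));
        return merged;
    }
}
Taking the per-token maximum keeps a merged chunk retrievable whenever any of its constituent chunks would have matched.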
public void testVeryLongInput_Float() {
int batchSize = 5;
int chunkSize = 20;
int numberOfWordsInPassage = (chunkSize * 10000);
var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");
var finalListener = testListener();
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
.batchRequestsWithListeners(finalListener);
// The very long passage is split into 10000 chunks for inference, so
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
// Produce inference results for each request, with increasing weights.
float weight = 0f;
for (var batch : batches) {
var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1 / 16384f;
embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight }));
}
batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings));
}
assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(3));
// The first input has the embedding with weight 1/16384.
ChunkedInference inference = finalListener.results.get(0);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f }));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
// inference. They get the embeddings with weights 2/16384 ... 10001/16384.
// Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
// and therefore 380 or 400 words. For each, the average weight is computed.
inference = finalListener.results.get(1);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(512));
// The first merged chunk consists of 20 small chunks (so 400 words) and the weight
// is the average of the weights 2/16384 ... 21/16384.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) }));
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
// is the average of the weights 9983/16384 ... 10001/16384.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) }));
// The last input has the embedding with weight 10002/16384.
inference = finalListener.results.get(2);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f }));
}
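The float assertions above follow from element-wise averaging: for an arithmetic sequence of weights the mean is (first + last) / 2, e.g. (2 + 21) / (2 * 16384f) for the first merged chunk. A minimal sketch of such an average merge (hypothetical helper name, assuming all chunks share one dimensionality):
import java.util.List;
class FloatMergeSketch {
    // Element-wise mean across the merged chunks' embeddings.
    static float[] merge(List<float[]> chunks) {
        float[] sum = new float[chunks.get(0).length];
        for (float[] values : chunks) {
            for (int i = 0; i < sum.length; i++) {
                sum[i] += values[i];
            }
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] /= chunks.size();
        }
        return sum;
    }
}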
public void testVeryLongInput_Byte() {
int batchSize = 5;
int chunkSize = 20;
int numberOfWordsInPassage = (chunkSize * 10000);
var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");
var finalListener = testListener();
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
.batchRequestsWithListeners(finalListener);
// The very long passage is split into 10000 chunks for inference, so
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
// Produce inference results for each request, with increasing weights.
byte weight = 0;
for (var batch : batches) {
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1;
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight }));
}
batch.listener().onResponse(new TextEmbeddingByteResults(embeddings));
}
assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(3));
// The first input has the embedding with weight 1.
ChunkedInference inference = finalListener.results.get(0);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 1 }));
// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
// inference. They get the embeddings with weights 2 ... 10001 (modulo 256, as bytes overflow).
// Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
// and therefore 380 or 400 words. For each, the average weight is computed.
inference = finalListener.results.get(1);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(512));
// The first merged chunk consists of 20 small chunks (so 400 words) and the weight
// is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 12 }));
// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
// is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so
// the average of -1, 0, 1, ... , 17, so 8.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 8 }));
// The last input has the embedding with weight 10002 % 256 = 18.
inference = finalListener.results.get(2);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 18 }));
}
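The byte case needs one extra step: accumulating in a wider type to avoid overflow and then rounding, so the mean of 2 ... 21 (11.5) becomes 12 as asserted above. A minimal sketch under those assumptions (hypothetical helper, not the production code):
import java.util.List;
class ByteMergeSketch {
    // Sum in int, then round half-up; e.g. 230 / 20 = 11.5 rounds to 12,
    // and the -1 ... 17 overflow case averages to exactly 8.
    static byte[] merge(List<byte[]> chunks) {
        int[] sums = new int[chunks.get(0).length];
        for (byte[] values : chunks) {
            for (int i = 0; i < sums.length; i++) {
                sums[i] += values[i];
            }
        }
        byte[] merged = new byte[sums.length];
        for (int i = 0; i < sums.length; i++) {
            merged[i] = (byte) Math.round((double) sums[i] / chunks.size());
        }
        return merged;
    }
}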
public void testMergingListener_Float() { public void testMergingListener_Float() {
int batchSize = 5; int batchSize = 5;
int chunkSize = 20; int chunkSize = 20;
@ -246,10 +500,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) { for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
} }
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc"); List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
var finalListener = testListener(); var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2)); assertThat(batches, hasSize(2));
// 4 inputs in 2 batches // 4 inputs in 2 batches
@ -275,7 +529,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedFloatResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(0), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small"));
} }
{ {
// this is the large input split in multiple chunks // this is the large input split in multiple chunks
@ -283,26 +537,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(6)); assertThat(chunkedFloatResult.chunks(), hasSize(6));
assertThat(chunkedFloatResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309))); assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(chunkedFloatResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629))); assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 "));
assertThat(chunkedFloatResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949))); assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 "));
assertThat(chunkedFloatResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269))); assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 "));
assertThat(chunkedFloatResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589))); assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 "));
assertThat(chunkedFloatResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675))); assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 "));
} }
{ {
var chunkedResult = finalListener.results.get(2); var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedFloatResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(2), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small"));
} }
{ {
var chunkedResult = finalListener.results.get(3); var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedFloatResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(3), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small"));
} }
} }
@ -318,10 +572,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) { for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
} }
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc"); List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
var finalListener = testListener(); var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2)); assertThat(batches, hasSize(2));
// 4 inputs in 2 batches // 4 inputs in 2 batches
@ -347,7 +601,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1)); assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small"));
} }
{ {
// this is the large input split in multiple chunks // this is the large input split in multiple chunks
@ -355,26 +609,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6)); assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 "));
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 "));
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 "));
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 "));
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 "));
} }
{ {
var chunkedResult = finalListener.results.get(2); var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1)); assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small"));
} }
{ {
var chunkedResult = finalListener.results.get(3); var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1)); assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small"));
} }
} }
@ -390,10 +644,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) { for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
} }
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc"); List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
var finalListener = testListener(); var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2)); assertThat(batches, hasSize(2));
// 4 inputs in 2 batches // 4 inputs in 2 batches
@ -419,7 +673,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1)); assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small"));
} }
{ {
// this is the large input split in multiple chunks // this is the large input split in multiple chunks
@ -427,26 +681,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6)); assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 "));
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 "));
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 "));
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 "));
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675))); assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 "));
} }
{ {
var chunkedResult = finalListener.results.get(2); var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1)); assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small"));
} }
{ {
var chunkedResult = finalListener.results.get(3); var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1)); assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small"));
} }
} }
@ -462,10 +716,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) { for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
} }
List<String> inputs = List.of("a", "bb", "ccc", passageBuilder.toString()); List<String> inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString());
var finalListener = testListener(); var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(3)); assertThat(batches, hasSize(3));
// 4 inputs in 3 batches // 4 inputs in 3 batches
@ -498,21 +752,21 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1)); assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedSparseResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(0), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small"));
} }
{ {
var chunkedResult = finalListener.results.get(1); var chunkedResult = finalListener.results.get(1);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1)); assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedSparseResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(1), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small"));
} }
{ {
var chunkedResult = finalListener.results.get(2); var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1)); assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedSparseResult.chunks().get(0).offset()); assertThat(getMatchedText(inputs.get(2), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small"));
} }
{ {
// this is the large input split in multiple chunks // this is the large input split in multiple chunks
@ -520,9 +774,9 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each
assertThat(chunkedSparseResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 149))); assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(chunkedSparseResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(149, 309))); assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 "));
assertThat(chunkedSparseResult.chunks().get(8).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1350))); assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 "));
} }
} }
@ -545,7 +799,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
} }
}; };
var batches = new EmbeddingRequestChunker(inputs, 10, 100, 0).batchRequestsWithListeners(listener); var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener);
assertThat(batches, hasSize(1)); assertThat(batches, hasSize(1));
var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>(); var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
@ -559,6 +813,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
return new ChunkedResultsListener(); return new ChunkedResultsListener();
} }
private static String getMatchedText(String text, ChunkedInference.TextOffset offset) {
return text.substring(offset.start(), offset.end());
}
private static class ChunkedResultsListener implements ActionListener<List<ChunkedInference>> { private static class ChunkedResultsListener implements ActionListener<List<ChunkedInference>> {
List<ChunkedInference> results; List<ChunkedInference> results;

View file

@ -31,9 +31,9 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
public class AmazonBedrockActionCreatorTests extends ESTestCase { public class AmazonBedrockActionCreatorTests extends ESTestCase {

View file

@ -34,13 +34,13 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;

View file

@ -42,13 +42,13 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;

View file

@ -40,6 +40,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
@ -47,8 +49,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;

View file

@ -40,13 +40,13 @@ import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;

View file

@@ -42,13 +42,13 @@ import java.net.URISyntaxException;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;


@@ -39,13 +39,13 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;


@@ -41,12 +41,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;


@@ -45,13 +45,13 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;


@@ -20,11 +20,11 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 import org.junit.After;


@@ -38,12 +38,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModelTests.createModel;
 import static org.hamcrest.Matchers.aMapWithSize;


@@ -19,13 +19,13 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
 import org.elasticsearch.xpack.inference.common.TruncatorTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
@@ -212,7 +212,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         var result = listener.actionGet(TIMEOUT);
-        assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+        assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
         assertThat(webServer.requests(), hasSize(1));
         assertNull(webServer.requests().get(0).getUri().getQuery());
@@ -325,7 +325,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         var result = listener.actionGet(TIMEOUT);
-        assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+        assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
         assertThat(webServer.requests(), hasSize(2));
         {
@@ -383,7 +383,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         var result = listener.actionGet(TIMEOUT);
-        assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
+        assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
         assertThat(webServer.requests(), hasSize(1));


@@ -43,12 +43,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests.createModel;
 import static org.hamcrest.Matchers.aMapWithSize;


@@ -34,6 +34,8 @@ import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
@@ -41,8 +43,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
 import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap;


@@ -43,6 +43,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
@@ -52,7 +53,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
 import static org.hamcrest.Matchers.containsString;


@@ -39,6 +39,7 @@ import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
@@ -46,7 +47,6 @@ import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiAct
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;


@@ -36,12 +36,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;


@@ -44,14 +44,14 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationBinary;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;


@@ -33,9 +33,9 @@ import java.nio.charset.CharacterCodingException;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;


@@ -31,11 +31,11 @@ import java.util.List;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;


@@ -44,13 +44,13 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;


@@ -13,16 +13,16 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentEOFException;
 import org.elasticsearch.xcontent.XContentParseException;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.request.Request;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.Map;
-import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
+import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;


@@ -179,13 +179,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());
         assert elementType == DenseVectorFieldMapper.ElementType.BYTE || elementType == DenseVectorFieldMapper.ElementType.BIT;
-        List<TextEmbeddingByteResults.Chunk> chunks = new ArrayList<>();
+        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
         for (String input : inputs) {
             byte[] values = new byte[embeddingLength];
             for (int j = 0; j < values.length; j++) {
                 values[j] = randomByte();
             }
-            chunks.add(new TextEmbeddingByteResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(
+                new EmbeddingResults.Chunk(
+                    new TextEmbeddingByteResults.Embedding(values),
+                    new ChunkedInference.TextOffset(0, input.length())
+                )
+            );
         }
         return new ChunkedInferenceEmbedding(chunks);
     }
@@ -195,13 +200,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());
         assert elementType == DenseVectorFieldMapper.ElementType.FLOAT;
-        List<TextEmbeddingFloatResults.Chunk> chunks = new ArrayList<>();
+        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
         for (String input : inputs) {
             float[] values = new float[embeddingLength];
             for (int j = 0; j < values.length; j++) {
                 values[j] = randomFloat();
             }
-            chunks.add(new TextEmbeddingFloatResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(
+                new EmbeddingResults.Chunk(
+                    new TextEmbeddingFloatResults.Embedding(values),
+                    new ChunkedInference.TextOffset(0, input.length())
+                )
+            );
         }
         return new ChunkedInferenceEmbedding(chunks);
     }
@@ -211,13 +221,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
     }
     public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingSparse(List<String> inputs, boolean withFloats) {
-        List<SparseEmbeddingResults.Chunk> chunks = new ArrayList<>();
+        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
         for (String input : inputs) {
             var tokens = new ArrayList<WeightedToken>();
             for (var token : input.split("\\s+")) {
                 tokens.add(new WeightedToken(token, withFloats ? randomFloat() : randomIntBetween(1, 255)));
             }
-            chunks.add(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(
+                new EmbeddingResults.Chunk(
+                    new SparseEmbeddingResults.Embedding(tokens, false),
+                    new ChunkedInference.TextOffset(0, input.length())
+                )
+            );
         }
         return new ChunkedInferenceEmbedding(chunks);
     }
@@ -309,7 +324,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
     ) {
         switch (field.inference().modelSettings().taskType()) {
             case SPARSE_EMBEDDING -> {
-                List<SparseEmbeddingResults.Chunk> chunks = new ArrayList<>();
+                List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
                 for (var entry : field.inference().chunks().entrySet()) {
                     String entryField = entry.getKey();
                     List<SemanticTextField.Chunk> entryChunks = entry.getValue();
@@ -320,7 +335,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
                         String matchedText = matchedTextIt.next();
                         ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, chunk, matchedText);
                         var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType());
-                        chunks.add(new SparseEmbeddingResults.Chunk(tokens, offset));
+                        chunks.add(new EmbeddingResults.Chunk(new SparseEmbeddingResults.Embedding(tokens, false), offset));
                     }
                 }
                 return new ChunkedInferenceEmbedding(chunks);
@@ -343,11 +358,11 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
                         String matchedText = matchedTextIt.next();
                         ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText);
                         double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType());
-                        EmbeddingResults.Chunk chunk = switch (elementType) {
-                            case FLOAT -> new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), offset);
-                            case BYTE, BIT -> new TextEmbeddingByteResults.Chunk(byteArrayOf(values), offset);
+                        EmbeddingResults.Embedding<?> embedding = switch (elementType) {
+                            case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
+                            case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values));
                         };
-                        chunks.add(chunk);
+                        chunks.add(new EmbeddingResults.Chunk(embedding, offset));
                     }
                 }
                 return new ChunkedInferenceEmbedding(chunks);
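
Taken together, the hunks above converge on a single construction pattern: the specialized SparseEmbeddingResults.Chunk, TextEmbeddingFloatResults.Chunk and TextEmbeddingByteResults.Chunk records disappear, and every call site builds the generic EmbeddingResults.Chunk, which pairs an embedding with the text offset it covers. A minimal sketch of that pattern, assuming only the constructors visible at these call sites (the helper name and the literal values are illustrative, not from the change):

    // Sketch only: the chunk/embedding constructors as they appear at the call sites above.
    // buildChunk is a hypothetical helper, not part of the change.
    static EmbeddingResults.Chunk buildChunk(String input) {
        // Any embedding flavor fits the same slot; a float embedding is shown here.
        var embedding = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f });
        // One chunk covering the whole input, as in the random-chunk helpers above.
        return new EmbeddingResults.Chunk(embedding, new ChunkedInference.TextOffset(0, input.length()));
    }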


@@ -20,10 +20,10 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -953,7 +953,7 @@ public class ServiceUtilsTests extends ESTestCase {
         var model = mock(Model.class);
         when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
-        var textEmbedding = TextEmbeddingResultsTests.createRandomResults();
+        var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults();
         doAnswer(invocation -> {
             ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);


@@ -31,6 +31,7 @@ import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -41,7 +42,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderT
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;


@@ -65,13 +65,13 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettingsTests.getAmazonBedrockSecretSettingsMap;
@@ -1458,10 +1458,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123F, 0.678F },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1470,10 +1470,10 @@
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.223F, 0.278F },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
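
The AmazonBedrock hunks above, like the AzureAiStudio, AzureOpenAi and Cohere hunks that follow, apply one assertion migration: the chunk is no longer the typed result, so the instanceOf check and the cast move from the chunk itself to its nested embedding, and the raw array is read through values(). A condensed before/after sketch, with illustrative test values:

    // Before: the chunk was the typed record and exposed the float array directly.
    assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
    float[] before = ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding();

    // After: the chunk is generic; the typed embedding is nested and exposes values().
    assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
    float[] after = ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values();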


@@ -1205,10 +1205,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.0123f, -0.0123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1217,10 +1217,10 @@
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 1.0123f, -1.0123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }


@@ -63,6 +63,7 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -72,7 +73,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap;
@@ -1355,10 +1355,10 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123f, -0.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1367,10 +1367,10 @@
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 1.123f, -1.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }


@ -67,6 +67,7 @@ import java.util.concurrent.TimeUnit;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -75,7 +76,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
@@ -1468,7 +1468,7 @@ public class CohereServiceTests extends ESTestCase {
 assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
 assertArrayEquals(
 new float[] { 0.123f, -0.123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }
@@ -1479,7 +1479,7 @@ public class CohereServiceTests extends ESTestCase {
 assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
 assertArrayEquals(
 new float[] { 0.223f, -0.223f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }
@@ -1565,16 +1565,22 @@ public class CohereServiceTests extends ESTestCase {
 var byteResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(byteResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset());
-assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
-assertArrayEquals(new byte[] { 23, -23 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
+assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+assertArrayEquals(
+    new byte[] { 23, -23 },
+    ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+);
 }
 {
 assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
 var byteResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(byteResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset());
-assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
-assertArrayEquals(new byte[] { 24, -24 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
+assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+assertArrayEquals(
+    new byte[] { 24, -24 },
+    ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+);
 }
 MatcherAssert.assertThat(webServer.requests(), hasSize(1));
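A minimal sketch (not part of this commit) of the access pattern these updated assertions exercise, assuming only the accessors visible in this diff; the class and helper names below are hypothetical, for illustration only:

import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;

final class ChunkAccessExample {
    // A chunk is now a generic EmbeddingResults.Chunk carrying offset() and embedding();
    // the embedding type is recovered by downcasting the embedding record itself,
    // not a specialized Chunk subclass (TextEmbeddingFloatResults.Chunk is removed).
    static float[] firstChunkFloatValues(ChunkedInferenceEmbedding result) {
        EmbeddingResults.Chunk chunk = result.chunks().get(0);
        return ((TextEmbeddingFloatResults.Embedding) chunk.embedding()).values();
    }
}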

@@ -37,7 +37,9 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.InferencePlugin;
@@ -48,7 +50,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
@@ -645,8 +646,11 @@ public class ElasticInferenceServiceTests extends ESTestCase {
 sparseResult.chunks(),
 is(
 List.of(
-new SparseEmbeddingResults.Chunk(
-List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+new EmbeddingResults.Chunk(
+    new SparseEmbeddingResults.Embedding(
+        List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+        false
+    ),
 new ChunkedInference.TextOffset(0, "input text".length())
 )
 )
@@ -767,8 +771,11 @@ public class ElasticInferenceServiceTests extends ESTestCase {
 sparseResult.chunks(),
 is(
 List.of(
-new SparseEmbeddingResults.Chunk(
-List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+new EmbeddingResults.Chunk(
+    new SparseEmbeddingResults.Embedding(
+        List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+        false
+    ),
 new ChunkedInference.TextOffset(0, "input text".length())
 )
 )
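A sketch of the construction shape these assertions expect, assuming the two-argument SparseEmbeddingResults.Embedding constructor shown in this diff; the name and meaning of the trailing boolean are not visible here and are assumed to be a truncation flag:

import java.util.List;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;

final class SparseChunkExample {
    // Sparse chunks now wrap a SparseEmbeddingResults.Embedding inside the generic
    // EmbeddingResults.Chunk, pairing the weighted tokens with the text offset they cover.
    static EmbeddingResults.Chunk expectedChunk(String inputText) {
        return new EmbeddingResults.Chunk(
            new SparseEmbeddingResults.Embedding(
                List.of(new WeightedToken("hello", 2.1259406f)),
                false // assumed: truncation flag
            ),
            new ChunkedInference.TextOffset(0, inputText.length())
        );
    }
}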

@@ -896,20 +896,20 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
 var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
 assertThat(result1.chunks(), hasSize(1));
-assertThat(result1.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
-((TextEmbeddingFloatResults.Chunk) result1.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(),
 0.0001f
 );
 assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
 assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
 var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
 assertThat(result2.chunks(), hasSize(1));
-assertThat(result2.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
-((TextEmbeddingFloatResults.Chunk) result2.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(),
 0.0001f
 );
 assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
@@ -972,18 +972,18 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 assertThat(chunkedResponse, hasSize(2));
 assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
 var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
-assertThat(result1.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
 assertEquals(
 ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
-((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
+((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens()
 );
 assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
 assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
 var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
-assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
 assertEquals(
 ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
-((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
+((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens()
 );
 assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
 gotResults.set(true);
@@ -1044,18 +1044,18 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 assertThat(chunkedResponse, hasSize(2));
 assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
 var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
-assertThat(result1.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
 assertEquals(
 ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
-((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
+((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens()
 );
 assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
 assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
 var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
-assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
 assertEquals(
 ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
-((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
+((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens()
 );
 assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
 gotResults.set(true);
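The read-side counterpart, sketched under the same assumptions: weightedTokens() on the old specialized chunk becomes tokens() on the embedding record; the helper name below is hypothetical:

import java.util.List;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;

final class SparseTokensExample {
    // Pull the weighted tokens back out of the first chunk of a chunked sparse result.
    static List<WeightedToken> firstChunkTokens(ChunkedInferenceEmbedding result) {
        return ((SparseEmbeddingResults.Embedding) result.chunks().get(0).embedding()).tokens();
    }
}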

@@ -64,6 +64,7 @@ import java.util.stream.Collectors;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -72,7 +73,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -882,11 +882,11 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.0123f, -0.0123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -897,11 +897,11 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.0456f, -0.0456f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -25,6 +25,7 @@ import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -109,8 +110,8 @@ public class HuggingFaceElserServiceTests extends ESTestCase {
 sparseResult.chunks(),
 is(
 List.of(
-new SparseEmbeddingResults.Chunk(
-List.of(new WeightedToken(".", 0.13315596f)),
+new EmbeddingResults.Chunk(
+new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(".", 0.13315596f)), false),
 new ChunkedInference.TextOffset(0, "abc".length())
 )
 )
@@ -35,12 +35,12 @@ import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
@@ -58,13 +58,13 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -788,10 +788,10 @@ public class HuggingFaceServiceTests extends ESTestCase {
 var embeddingResult = (ChunkedInferenceEmbedding) result;
 assertThat(embeddingResult.chunks(), hasSize(1));
 assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
-assertThat(embeddingResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 new float[] { -0.0123f, 0.0123f },
-((TextEmbeddingFloatResults.Chunk) embeddingResult.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(),
 0.001f
 );
 assertThat(webServer.requests(), hasSize(1));
@@ -842,10 +842,10 @@ public class HuggingFaceServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 new float[] { 0.123f, -0.123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }
@@ -68,6 +68,7 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -76,7 +77,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -734,11 +734,11 @@ public class IbmWatsonxServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.0123f, -0.0123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -749,11 +749,11 @@ public class IbmWatsonxServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.0456f, -0.0456f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -63,6 +63,7 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -71,7 +72,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
@@ -1833,10 +1833,10 @@ public class JinaAIServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 new float[] { 0.123f, -0.123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }
@@ -1845,10 +1845,10 @@ public class JinaAIServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 new float[] { 0.223f, -0.223f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }
@@ -684,11 +684,11 @@ public class MistralServiceTests extends ESTestCase {
 assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
 var floatResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(floatResult.chunks(), hasSize(1));
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.123f, -0.123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -696,11 +696,11 @@ public class MistralServiceTests extends ESTestCase {
 assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
 var floatResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(floatResult.chunks(), hasSize(1));
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.223f, -0.223f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -72,6 +72,7 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
@@ -81,7 +82,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
 import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap;
@@ -1871,11 +1871,11 @@ public class OpenAiServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(0);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.123f, -0.123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -1884,11 +1884,11 @@ public class OpenAiServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertTrue(
 Arrays.equals(
 new float[] { 0.223f, -0.223f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
 )
 );
 }
@@ -13,7 +13,7 @@ import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.junit.Before;
 import org.mockito.Mock;
@@ -14,12 +14,12 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.EmptyTaskSettingsTests;
 import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
 import org.junit.Before;
 import org.mockito.Mock;
@@ -61,6 +61,7 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -69,7 +70,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
@@ -1840,10 +1840,10 @@ public class VoyageAIServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.getFirst();
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset());
-assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 new float[] { 0.123f, -0.123f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().getFirst()).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }
@@ -1852,10 +1852,10 @@ public class VoyageAIServiceTests extends ESTestCase {
 var floatResult = (ChunkedInferenceEmbedding) results.get(1);
 assertThat(floatResult.chunks(), hasSize(1));
 assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset());
-assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
 assertArrayEquals(
 new float[] { 0.223f, -0.223f },
-((TextEmbeddingFloatResults.Chunk) floatResult.chunks().getFirst()).embedding(),
+((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
 0.0f
 );
 }