Add max.chunks to EmbeddingRequestChunker to prevent OOM (#123150)
* add max number of chunks
* wire merge function
* implement sparse merge function
* move tests to correct package/file
* float merge function
* bytes merge function
* more accurate byte average
* spotless
* Fix/improve EmbeddingRequestChunkerTests
* Remove TODO
* remove unnecessary field
* remove Chunk generic
* add TODO
* Remove specialized chunks
* add comment
* Update docs/changelog/123150.yaml
* update changelog
This commit is contained in:
parent c24f77f547
commit a503497bce

77 changed files with 756 additions and 355 deletions
docs/changelog/123150.yaml (new file, 5 lines)

@@ -0,0 +1,5 @@
+pr: 123150
+summary: Limit the number of chunks for semantic text to prevent high memory usage
+area: Machine Learning
+type: feature
+issues: []
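In outline, the fix caps memory by merging rather than dropping: every chunk of an input is still sent to the inference service, but chunk i of n is folded into stored slot i * MAX_CHUNKS / n, and embeddings that land in the same slot are merged. A minimal, self-contained sketch of that folding with hypothetical names (the real logic lives in EmbeddingRequestChunker further down):

import java.util.ArrayList;
import java.util.List;
import java.util.function.BinaryOperator;

class ChunkCapSketch {
    // Fold n chunk embeddings into at most max stored slots. When max <= n,
    // i * max / n advances by at most one per step, so each chunk either
    // opens the next slot or merges into the current one.
    static <E> List<E> capChunks(List<E> chunkEmbeddings, int max, BinaryOperator<E> merge) {
        int n = chunkEmbeddings.size();
        List<E> slots = new ArrayList<>(Math.min(n, max));
        for (int i = 0; i < n; i++) {
            int target = n <= max ? i : i * max / n; // scale [0, n) onto [0, max)
            if (target == slots.size()) {
                slots.add(chunkEmbeddings.get(i)); // first chunk in this slot
            } else {
                slots.set(target, merge.apply(slots.get(target), chunkEmbeddings.get(i)));
            }
        }
        return slots;
    }
}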
ChunkedInferenceEmbedding.java

@@ -17,7 +17,7 @@ import java.util.List;
 import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
 
-public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
+public record ChunkedInferenceEmbedding(List<EmbeddingResults.Chunk> chunks) implements ChunkedInference {
 
     public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
         validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());

@@ -27,10 +27,7 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
             results.add(
                 new ChunkedInferenceEmbedding(
                     List.of(
-                        new SparseEmbeddingResults.Chunk(
-                            sparseEmbeddingResults.embeddings().get(i).tokens(),
-                            new TextOffset(0, inputs.get(i).length())
-                        )
+                        new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length()))
                     )
                 )
             );

@@ -41,10 +38,10 @@ public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
 
     @Override
     public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
-        var asChunk = new ArrayList<Chunk>();
-        for (var chunk : chunks()) {
-            asChunk.add(chunk.toChunk(xcontent));
+        List<Chunk> chunkedInferenceChunks = new ArrayList<>();
+        for (EmbeddingResults.Chunk embeddingResultsChunk : chunks()) {
+            chunkedInferenceChunks.add(new Chunk(embeddingResultsChunk.offset(), embeddingResultsChunk.embedding().toBytesRef(xcontent)));
         }
-        return asChunk.iterator();
+        return chunkedInferenceChunks.iterator();
     }
 }
EmbeddingResults.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.core.inference.results;
 
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.XContent;

@@ -19,31 +20,30 @@ import java.util.List;
  * A call to the inference service may contain multiple input texts, so these results may
  * contain multiple results.
  */
-public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
-    extends
-        InferenceServiceResults {
-
-    /**
-     * A resulting embedding together with the offset into the input text.
-     */
-    interface Chunk {
-        ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
-
-        ChunkedInference.TextOffset offset();
-    }
+public interface EmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends InferenceServiceResults {
 
     /**
      * A resulting embedding for one of the input texts to the inference service.
     */
-    interface Embedding<C extends Chunk> {
+    interface Embedding<E extends Embedding<E>> {
         /**
-         * Combines the resulting embedding with the offset into the input text into a chunk.
+         * Merges the existing embedding and provided embedding into a new embedding.
         */
-        C toChunk(ChunkedInference.TextOffset offset);
+        E merge(E embedding);
+
+        /**
+         * Serializes the embedding to bytes.
+         */
+        BytesReference toBytesRef(XContent xContent) throws IOException;
     }
 
     /**
      * The resulting list of embeddings for the input texts to the inference service.
     */
     List<E> embeddings();
+
+    /**
+     * A resulting embedding together with the offset into the input text.
+     */
+    record Chunk(Embedding<?> embedding, ChunkedInference.TextOffset offset) {}
 }
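The E extends EmbeddingResults.Embedding<E> bound is the self-referential ("F-bounded") generics pattern: each concrete embedding type can only be merged with embeddings of its own type, checked at compile time. A minimal illustration with a hypothetical single-valued embedding (not part of the commit):

interface Embedding<E extends Embedding<E>> {
    E merge(E other);
}

// Hypothetical one-dimensional embedding that tracks how many inputs were
// merged into it, mirroring the running average used by the float results.
record Scalar(double value, int count) implements Embedding<Scalar> {
    @Override
    public Scalar merge(Scalar other) {
        int n = count + other.count;
        return new Scalar((count * value + other.count * other.value) / n, n);
    }
}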
SparseEmbeddingResults.java

@@ -14,7 +14,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;

@@ -27,17 +26,17 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
-public record SparseEmbeddingResults(List<Embedding> embeddings)
-    implements
-        EmbeddingResults<SparseEmbeddingResults.Chunk, SparseEmbeddingResults.Embedding> {
+public record SparseEmbeddingResults(List<Embedding> embeddings) implements EmbeddingResults<SparseEmbeddingResults.Embedding> {
 
     public static final String NAME = "sparse_embedding_results";
     public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();

@@ -124,7 +123,7 @@ public record SparseEmbeddingResults(List<Embedding> embeddings)
         implements
             Writeable,
             ToXContentObject,
-            EmbeddingResults.Embedding<Chunk> {
+            EmbeddingResults.Embedding<Embedding> {
 
         public static final String EMBEDDING = "embedding";
         public static final String IS_TRUNCATED = "is_truncated";

@@ -175,18 +174,35 @@ public record SparseEmbeddingResults(List<Embedding> embeddings)
         }
 
         @Override
-        public Chunk toChunk(ChunkedInference.TextOffset offset) {
-            return new Chunk(tokens, offset);
-        }
-    }
-
-    public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
+        public Embedding merge(Embedding embedding) {
+            // This code assumes that the tokens are sorted by weight in descending order.
+            // If that's not the case, the resulting merged embedding will be incorrect.
+            List<WeightedToken> mergedTokens = new ArrayList<>();
+            Set<String> seenTokens = new HashSet<>();
+            int i = 0;
+            int j = 0;
+            // TODO: maybe truncate tokens here when it's getting too large?
+            while (i < tokens().size() || j < embedding.tokens().size()) {
+                WeightedToken token;
+                if (i == tokens().size()) {
+                    token = embedding.tokens().get(j++);
+                } else if (j == embedding.tokens().size()) {
+                    token = tokens().get(i++);
+                } else if (tokens.get(i).weight() > embedding.tokens().get(j).weight()) {
+                    token = tokens().get(i++);
+                } else {
+                    token = embedding.tokens().get(j++);
+                }
+                if (seenTokens.add(token.token())) {
+                    mergedTokens.add(token);
+                }
+            }
+            boolean mergedIsTruncated = isTruncated || embedding.isTruncated();
+            return new Embedding(mergedTokens, mergedIsTruncated);
         }
 
-        private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
+        @Override
+        public BytesReference toBytesRef(XContent xContent) throws IOException {
             XContentBuilder b = XContentBuilder.builder(xContent);
             b.startObject();
             for (var weightedToken : tokens) {
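Since both token lists are sorted by weight in descending order, the two-pointer walk always takes the heavier remaining token first, and the seenTokens check drops later (lighter) duplicates, so the merge effectively keeps the maximum weight per token. For example (matching the test further down), merging [this:1.0, is:0.8, the:0.6, first:0.4, embedding:0.2] with [this:0.95, is:0.85, another:0.65, embedding:0.15] yields [this:1.0, is:0.85, another:0.65, the:0.6, first:0.4, embedding:0.2].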
TextEmbeddingBitResults.java

@@ -40,9 +40,11 @@ import java.util.Objects;
  * ]
  * }
  */
+// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
+// Embedding.merge method for bits. TODO: implement a proper merge method
 public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
     implements
-        TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
+        TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
     public static final String NAME = "text_embedding_service_bit_results";
     public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
 
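The caveat in that note matters because bit embeddings pack 8 dimensions into each byte, so the inherited byte-wise arithmetic average mixes unrelated bit positions: for example, averaging the packed bytes 0b10000000 (-128) and 0b01111111 (127) yields 0 (0b00000000), whereas a per-dimension merge would have to vote on each of the 8 bits separately.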
TextEmbeddingByteResults.java

@@ -15,7 +15,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.ToXContentObject;

@@ -48,9 +47,7 @@ import java.util.Objects;
  * ]
  * }
  */
-public record TextEmbeddingByteResults(List<Embedding> embeddings)
-    implements
-        TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
+public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
     public static final String NAME = "text_embedding_service_byte_results";
     public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
 

@@ -118,9 +115,20 @@ public record TextEmbeddingByteResults(List<Embedding> embeddings)
         return Objects.hash(embeddings);
     }
 
-    public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
+    // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+    // embeddings should happen in between serializations.
+    public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
+        implements
+            Writeable,
+            ToXContentObject,
+            EmbeddingResults.Embedding<Embedding> {
 
         public static final String EMBEDDING = "embedding";
 
+        public Embedding(byte[] values) {
+            this(values, null, 1);
+        }
+
         public Embedding(StreamInput in) throws IOException {
             this(in.readByteArray());
         }

@@ -187,25 +195,26 @@ public record TextEmbeddingByteResults(List<Embedding> embeddings)
         }
 
         @Override
-        public Chunk toChunk(ChunkedInference.TextOffset offset) {
-            return new Chunk(values, offset);
-        }
-    }
-
-    /**
-     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
-     */
-    public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
+        public Embedding merge(Embedding embedding) {
+            byte[] newValues = new byte[values.length];
+            int[] newSumMergedValues = new int[values.length];
+            int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
+            for (int i = 0; i < values.length; i++) {
+                newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i])
+                    + (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
+                // Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
+                // closest byte instead of truncating.
+                newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
+            }
+            return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
         }
 
-        private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
+        @Override
+        public BytesReference toBytesRef(XContent xContent) throws IOException {
             XContentBuilder builder = XContentBuilder.builder(xContent);
             builder.startArray();
-            for (byte v : value) {
-                builder.value(v);
+            for (byte value : values) {
+                builder.value(value);
             }
             builder.endArray();
             return BytesReference.bytes(builder);
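The + newNumberOfMergedEmbeddings / 2 term implements round-half-up on the integer division. Worked through the test further down: merging {1, 1, -128} and {1, 0, 127} gives sums {2, 1, -1} over n = 2, so (2+1)/2 = 1, (1+1)/2 = 1, (-1+1)/2 = 0, i.e. {1, 1, 0}; merging in {0, 0, 100} gives sums {2, 1, 99} over n = 3, so (2+1)/3 = 1, (1+1)/3 = 0, (99+1)/3 = 33, i.e. {1, 0, 33}. Keeping the exact integer sum in sumMergedValues avoids compounding rounding error across repeated merges.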
TextEmbeddingFloatResults.java

@@ -16,7 +16,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;

@@ -53,9 +52,7 @@ import java.util.stream.Collectors;
  * ]
  * }
  */
-public record TextEmbeddingFloatResults(List<Embedding> embeddings)
-    implements
-        TextEmbeddingResults<TextEmbeddingFloatResults.Chunk, TextEmbeddingFloatResults.Embedding> {
+public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
     public static final String NAME = "text_embedding_service_results";
     public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
 

@@ -155,9 +152,19 @@ public record TextEmbeddingFloatResults(List<Embedding> embeddings)
         return Objects.hash(embeddings);
     }
 
-    public record Embedding(float[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
+    // Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+    // embeddings should happen in between serializations.
+    public record Embedding(float[] values, int numberOfMergedEmbeddings)
+        implements
+            Writeable,
+            ToXContentObject,
+            EmbeddingResults.Embedding<Embedding> {
         public static final String EMBEDDING = "embedding";
 
+        public Embedding(float[] values) {
+            this(values, 1);
+        }
+
         public Embedding(StreamInput in) throws IOException {
             this(in.readFloatArray());
         }

@@ -221,25 +228,21 @@ public record TextEmbeddingFloatResults(List<Embedding> embeddings)
         }
 
         @Override
-        public Chunk toChunk(ChunkedInference.TextOffset offset) {
-            return new Chunk(values, offset);
-        }
-    }
-
-    public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-        public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-            return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
+        public Embedding merge(Embedding embedding) {
+            float[] mergedValues = new float[values.length];
+            for (int i = 0; i < values.length; i++) {
+                mergedValues[i] = (numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i])
+                    / (numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
+            }
+            return new Embedding(mergedValues, numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
         }
 
         /**
          * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
         */
-        private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
+        @Override
+        public BytesReference toBytesRef(XContent xContent) throws IOException {
             XContentBuilder b = XContentBuilder.builder(xContent);
             b.startArray();
-            for (float v : value) {
-                b.value(v);
+            for (float value : values) {
+                b.value(value);
             }
             b.endArray();
             return BytesReference.bytes(b);
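The float merge is a weighted running average, with numberOfMergedEmbeddings as the weight. From the test further down: merging {0.1, 0.2, 0.3, 0.4} with {0.0, 0.4, 0.1, 1.0} (both with count 1) gives {0.05, 0.3, 0.2, 0.7} with count 2; merging that with {0.2, 0.9, 0.8, 0.1} weights the sides 2:1, e.g. (2 * 0.05 + 1 * 0.2) / 3 = 0.1, producing {0.1, 0.5, 0.4, 0.5} with count 3, the same result as averaging all three embeddings at once.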
TextEmbeddingResults.java

@@ -7,9 +7,7 @@
 
 package org.elasticsearch.xpack.core.inference.results;
 
-public interface TextEmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
-    extends
-        EmbeddingResults<C, E> {
+public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
 
     /**
      * Returns the first text embedding entry in the result list's array size.
@@ -5,12 +5,11 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;

@@ -14,7 +14,6 @@ import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
 
 import java.io.IOException;
 import java.util.ArrayList;
SparseEmbeddingResultsTests.java

@@ -5,12 +5,11 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 

@@ -20,6 +19,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
 public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase<SparseEmbeddingResults> {

@@ -161,6 +161,44 @@ public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase
         );
     }
 
+    public void testEmbeddingMerge() {
+        SparseEmbeddingResults.Embedding embedding1 = new SparseEmbeddingResults.Embedding(
+            List.of(
+                new WeightedToken("this", 1.0f),
+                new WeightedToken("is", 0.8f),
+                new WeightedToken("the", 0.6f),
+                new WeightedToken("first", 0.4f),
+                new WeightedToken("embedding", 0.2f)
+            ),
+            true
+        );
+        SparseEmbeddingResults.Embedding embedding2 = new SparseEmbeddingResults.Embedding(
+            List.of(
+                new WeightedToken("this", 0.95f),
+                new WeightedToken("is", 0.85f),
+                new WeightedToken("another", 0.65f),
+                new WeightedToken("embedding", 0.15f)
+            ),
+            false
+        );
+        assertThat(
+            embedding1.merge(embedding2),
+            equalTo(
+                new SparseEmbeddingResults.Embedding(
+                    List.of(
+                        new WeightedToken("this", 1.0f),
+                        new WeightedToken("is", 0.85f),
+                        new WeightedToken("another", 0.65f),
+                        new WeightedToken("the", 0.6f),
+                        new WeightedToken("first", 0.4f),
+                        new WeightedToken("embedding", 0.2f)
+                    ),
+                    true
+                )
+            )
+        );
+    }
+
     public record EmbeddingExpectation(Map<String, Float> tokens, boolean isTruncated) {}
 
     public static Map<String, Object> buildExpectationSparseEmbeddings(List<EmbeddingExpectation> embeddings) {
@@ -5,13 +5,11 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;
TextEmbeddingByteResultsTests.java

@@ -5,12 +5,11 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;

@@ -18,6 +17,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
 public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingByteResults> {

@@ -115,6 +115,16 @@ public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingByteResults> {
         assertThat(firstEmbeddingSize, is(2));
     }
 
+    public void testEmbeddingMerge() {
+        TextEmbeddingByteResults.Embedding embedding1 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, -128 });
+        TextEmbeddingByteResults.Embedding embedding2 = new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 127 });
+        TextEmbeddingByteResults.Embedding embedding3 = new TextEmbeddingByteResults.Embedding(new byte[] { 0, 0, 100 });
+        TextEmbeddingByteResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 1, 0 })));
+        mergedEmbedding = mergedEmbedding.merge(embedding3);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingByteResults.Embedding(new byte[] { 1, 0, 33 })));
+    }
+
     @Override
     protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
         return TextEmbeddingByteResults::new;
TextEmbeddingFloatResultsTests.java (renamed from TextEmbeddingResultsTests.java)

@@ -5,13 +5,11 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;

@@ -19,9 +17,10 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
-public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingFloatResults> {
+public class TextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingFloatResults> {
     public static TextEmbeddingFloatResults createRandomResults() {
         int embeddings = randomIntBetween(1, 10);
         List<TextEmbeddingFloatResults.Embedding> embeddingResults = new ArrayList<>(embeddings);

@@ -116,6 +115,16 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<TextEmbeddingFloatResults> {
         assertThat(firstEmbeddingSize, is(2));
     }
 
+    public void testEmbeddingMerge() {
+        TextEmbeddingFloatResults.Embedding embedding1 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
+        TextEmbeddingFloatResults.Embedding embedding2 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.0f, 0.4f, 0.1f, 1.0f });
+        TextEmbeddingFloatResults.Embedding embedding3 = new TextEmbeddingFloatResults.Embedding(new float[] { 0.2f, 0.9f, 0.8f, 0.1f });
+        TextEmbeddingFloatResults.Embedding mergedEmbedding = embedding1.merge(embedding2);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.05f, 0.3f, 0.2f, 0.7f })));
+        mergedEmbedding = mergedEmbedding.merge(embedding3);
+        assertThat(mergedEmbedding, equalTo(new TextEmbeddingFloatResults.Embedding(new float[] { 0.1f, 0.5f, 0.4f, 0.5f })));
+    }
+
     @Override
     protected Writeable.Reader<TextEmbeddingFloatResults> instanceReader() {
         return TextEmbeddingFloatResults::new;
TestDenseInferenceServiceExtension.java

@@ -35,6 +35,7 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 
 import java.io.IOException;

@@ -181,8 +182,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExtension
                 results.add(
                     new ChunkedInferenceEmbedding(
                         List.of(
-                            new TextEmbeddingFloatResults.Chunk(
-                                nonChunkedResults.embeddings().get(i).values(),
+                            new EmbeddingResults.Chunk(
+                                nonChunkedResults.embeddings().get(i),
                                 new ChunkedInference.TextOffset(0, input.get(i).length())
                             )
                         )
TestSparseInferenceServiceExtension.java

@@ -33,6 +33,7 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 

@@ -172,7 +173,12 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExtension
             }
             results.add(
                 new ChunkedInferenceEmbedding(
-                    List.of(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.get(i).length())))
+                    List.of(
+                        new EmbeddingResults.Chunk(
+                            new SparseEmbeddingResults.Embedding(tokens, false),
+                            new ChunkedInference.TextOffset(0, input.get(i).length())
+                        )
+                    )
                 )
             );
         }
EmbeddingRequestChunker.java

@@ -36,7 +36,7 @@ import java.util.stream.Collectors;
  * processing and map the results back to the original element
  * in the input list.
  */
-public class EmbeddingRequestChunker {
+public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
 
     // Visible for testing
     record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<String> inputs) {

@@ -56,13 +56,19 @@ public class EmbeddingRequestChunker {
     private static final int DEFAULT_WORDS_PER_CHUNK = 250;
     private static final int DEFAULT_CHUNK_OVERLAP = 100;
 
-    private final List<String> inputs;
-    private final List<List<Request>> requests;
+    // The maximum number of chunks that is stored for any input text.
+    // If the configured chunker chunks the text into more chunks, each
+    // chunk is sent to the inference service separately, but the results
+    // are merged so that only this maximum number of chunks is stored.
+    private static final int MAX_CHUNKS = 512;
+
     private final List<BatchRequest> batchRequests;
     private final AtomicInteger resultCount = new AtomicInteger();
 
-    private final List<AtomicReferenceArray<EmbeddingResults.Embedding<?>>> results;
-    private final AtomicArray<Exception> errors;
+    private final List<List<Integer>> resultOffsetStarts;
+    private final List<List<Integer>> resultOffsetEnds;
+    private final List<AtomicReferenceArray<E>> resultEmbeddings;
+    private final AtomicArray<Exception> resultsErrors;
     private ActionListener<List<ChunkedInference>> finalListener;
 
     public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch) {

@@ -74,31 +80,41 @@ public class EmbeddingRequestChunker {
     }
 
     public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, ChunkingSettings chunkingSettings) {
-        this.inputs = inputs;
-        this.results = new ArrayList<>(inputs.size());
-        this.errors = new AtomicArray<>(inputs.size());
+        this.resultEmbeddings = new ArrayList<>(inputs.size());
+        this.resultOffsetStarts = new ArrayList<>(inputs.size());
+        this.resultOffsetEnds = new ArrayList<>(inputs.size());
+        this.resultsErrors = new AtomicArray<>(inputs.size());
 
         if (chunkingSettings == null) {
             chunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP);
         }
         Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
 
-        this.requests = new ArrayList<>(inputs.size());
-
+        List<Request> allRequests = new ArrayList<>();
         for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) {
             List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex), chunkingSettings);
-            List<Request> requestForInput = new ArrayList<>(chunks.size());
+            int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
+            resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
+            resultOffsetStarts.add(new ArrayList<>(resultCount));
+            resultOffsetEnds.add(new ArrayList<>(resultCount));
+
             for (int chunkIndex = 0; chunkIndex < chunks.size(); chunkIndex++) {
-                requestForInput.add(new Request(inputIndex, chunkIndex, chunks.get(chunkIndex), inputs));
+                // If the number of chunks is larger than the maximum allowed value,
+                // scale the indices to [0, MAX) with similar number of original
+                // chunks in the final chunks.
+                int targetChunkIndex = chunks.size() <= MAX_CHUNKS ? chunkIndex : chunkIndex * MAX_CHUNKS / chunks.size();
+                if (resultOffsetStarts.getLast().size() <= targetChunkIndex) {
+                    resultOffsetStarts.getLast().add(chunks.get(chunkIndex).start());
+                    resultOffsetEnds.getLast().add(chunks.get(chunkIndex).end());
+                } else {
+                    resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
+                }
+                allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
             }
-            requests.add(requestForInput);
-            // size the results array with the expected number of request/responses
-            results.add(new AtomicReferenceArray<>(chunks.size()));
         }
 
         AtomicInteger counter = new AtomicInteger();
-        this.batchRequests = requests.stream()
-            .flatMap(List::stream)
+        this.batchRequests = allRequests.stream()
             .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
             .values()
            .stream()

@@ -134,20 +150,26 @@ public class EmbeddingRequestChunker {
 
         @Override
         public void onResponse(InferenceServiceResults inferenceServiceResults) {
-            if (inferenceServiceResults instanceof EmbeddingResults<?, ?> embeddingResults) {
-                if (embeddingResults.embeddings().size() != request.requests.size()) {
-                    onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size()));
-                    return;
-                }
-                for (int i = 0; i < embeddingResults.embeddings().size(); i++) {
-                    results.get(request.requests().get(i).inputIndex())
-                        .set(request.requests().get(i).chunkIndex(), embeddingResults.embeddings().get(i));
-                }
-                if (resultCount.incrementAndGet() == batchRequests.size()) {
-                    sendFinalResponse();
-                }
-            } else {
+            if (inferenceServiceResults instanceof EmbeddingResults<?> == false) {
                 onFailure(unexpectedResultTypeException(inferenceServiceResults.getWriteableName()));
                 return;
             }
+            @SuppressWarnings("unchecked")
+            EmbeddingResults<E> embeddingResults = (EmbeddingResults<E>) inferenceServiceResults;
+            if (embeddingResults.embeddings().size() != request.requests.size()) {
+                onFailure(numResultsDoesntMatchException(embeddingResults.embeddings().size(), request.requests.size()));
+                return;
+            }
+            for (int i = 0; i < embeddingResults.embeddings().size(); i++) {
+                E newEmbedding = embeddingResults.embeddings().get(i);
+                resultEmbeddings.get(request.requests().get(i).inputIndex())
+                    .updateAndGet(
+                        request.requests().get(i).chunkIndex(),
+                        oldEmbedding -> oldEmbedding == null ? newEmbedding : oldEmbedding.merge(newEmbedding)
+                    );
+            }
+            if (resultCount.incrementAndGet() == batchRequests.size()) {
+                sendFinalResponse();
+            }
         }

@@ -171,7 +193,7 @@ public class EmbeddingRequestChunker {
         @Override
         public void onFailure(Exception e) {
             for (Request request : request.requests) {
-                errors.set(request.inputIndex(), e);
+                resultsErrors.set(request.inputIndex(), e);
             }
             if (resultCount.incrementAndGet() == batchRequests.size()) {
                 sendFinalResponse();

@@ -180,10 +202,10 @@ public class EmbeddingRequestChunker {
     }
 
     private void sendFinalResponse() {
-        var response = new ArrayList<ChunkedInference>(inputs.size());
-        for (int i = 0; i < inputs.size(); i++) {
-            if (errors.get(i) != null) {
-                response.add(new ChunkedInferenceError(errors.get(i)));
+        var response = new ArrayList<ChunkedInference>(resultEmbeddings.size());
+        for (int i = 0; i < resultEmbeddings.size(); i++) {
+            if (resultsErrors.get(i) != null) {
+                response.add(new ChunkedInferenceError(resultsErrors.get(i)));
             } else {
                 response.add(mergeResultsWithInputs(i));
             }

@@ -191,14 +213,15 @@ public class EmbeddingRequestChunker {
         finalListener.onResponse(response);
     }
 
-    private ChunkedInference mergeResultsWithInputs(int index) {
+    private ChunkedInference mergeResultsWithInputs(int inputIndex) {
+        List<Integer> startOffsets = resultOffsetStarts.get(inputIndex);
+        List<Integer> endOffsets = resultOffsetEnds.get(inputIndex);
+        AtomicReferenceArray<E> embeddings = resultEmbeddings.get(inputIndex);
+
         List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
-        List<Request> request = requests.get(index);
-        AtomicReferenceArray<EmbeddingResults.Embedding<?>> result = results.get(index);
-        for (int i = 0; i < request.size(); i++) {
-            EmbeddingResults.Chunk chunk = result.get(i)
-                .toChunk(new ChunkedInference.TextOffset(request.get(i).chunk.start(), request.get(i).chunk.end()));
-            chunks.add(chunk);
+        for (int i = 0; i < embeddings.length(); i++) {
+            ChunkedInference.TextOffset offset = new ChunkedInference.TextOffset(startOffsets.get(i), endOffsets.get(i));
+            chunks.add(new EmbeddingResults.Chunk(embeddings.get(i), offset));
         }
         return new ChunkedInferenceEmbedding(chunks);
     }
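To make the scaling concrete: with MAX_CHUNKS = 512 and an input that splits into 10,000 chunks, chunk i is stored in slot i * 512 / 10000, so each slot absorbs 19 or 20 consecutive chunks. Every chunk is still sent for inference, but the responses that map to one slot are folded together with Embedding#merge in updateAndGet, and the slot keeps the first chunk's start offset while its end offset is extended to the last chunk's end, so the stored chunk spans the whole merged region of the original text.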
@@ -741,7 +741,7 @@ public final class ServiceUtils {
             InputType.INGEST,
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener.delegateFailureAndWrap((delegate, r) -> {
-                if (r instanceof TextEmbeddingResults<?, ?> embeddingResults) {
+                if (r instanceof TextEmbeddingResults<?> embeddingResults) {
                     try {
                         delegate.onResponse(embeddingResults.getFirstEmbeddingSize());
                     } catch (Exception e) {
@@ -305,7 +305,7 @@ public class AlibabaCloudSearchService extends SenderService {
         AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
         var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             alibabaCloudSearchModel.getConfigurations().getChunkingSettings()

@@ -129,7 +129,7 @@ public class AmazonBedrockService extends SenderService {
         if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
             var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
 
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 inputs.getInputs(),
                 maxBatchSize,
                 baseAmazonBedrockModel.getConfigurations().getChunkingSettings()

@@ -121,7 +121,7 @@ public class AzureAiStudioService extends SenderService {
         if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
             var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
 
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 inputs.getInputs(),
                 EMBEDDING_MAX_BATCH_SIZE,
                 baseAzureAiStudioModel.getConfigurations().getChunkingSettings()

@@ -281,7 +281,7 @@ public class AzureOpenAiService extends SenderService {
         AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
         var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             azureOpenAiModel.getConfigurations().getChunkingSettings()

@@ -284,7 +284,7 @@ public class CohereService extends SenderService {
         CohereModel cohereModel = (CohereModel) model;
         var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             cohereModel.getConfigurations().getChunkingSettings()

@@ -725,7 +725,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalService {
 
         if (model instanceof ElasticsearchInternalModel esModel) {
 
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 input,
                 EMBEDDING_MAX_BATCH_SIZE,
                 esModel.getConfigurations().getChunkingSettings()

@@ -325,7 +325,7 @@ public class GoogleAiStudioService extends SenderService {
     ) {
         GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             googleAiStudioModel.getConfigurations().getChunkingSettings()

@@ -228,7 +228,7 @@ public class GoogleVertexAiService extends SenderService {
         GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
         var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             googleVertexAiModel.getConfigurations().getChunkingSettings()

@@ -127,7 +127,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         var huggingFaceModel = (HuggingFaceModel) model;
         var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             huggingFaceModel.getConfigurations().getChunkingSettings()

@@ -27,6 +27,7 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;

@@ -119,8 +120,8 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
                 results.add(
                     new ChunkedInferenceEmbedding(
                         List.of(
-                            new TextEmbeddingFloatResults.Chunk(
-                                textEmbeddingResults.embeddings().get(i).values(),
+                            new EmbeddingResults.Chunk(
+                                textEmbeddingResults.embeddings().get(i),
                                 new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
                             )
                         )

@@ -307,7 +307,7 @@ public class IbmWatsonxService extends SenderService {
     ) {
         IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model;
 
-        var batchedRequests = new EmbeddingRequestChunker(
+        var batchedRequests = new EmbeddingRequestChunker<>(
             input.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             model.getConfigurations().getChunkingSettings()

@@ -266,7 +266,7 @@ public class JinaAIService extends SenderService {
         JinaAIModel jinaaiModel = (JinaAIModel) model;
         var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             jinaaiModel.getConfigurations().getChunkingSettings()

@@ -110,7 +110,7 @@ public class MistralService extends SenderService {
         var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
 
         if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
-            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
                 inputs.getInputs(),
                 MistralConstants.MAX_BATCH_SIZE,
                 mistralEmbeddingsModel.getConfigurations().getChunkingSettings()

@@ -335,7 +335,7 @@ public class OpenAiService extends SenderService {
         OpenAiModel openAiModel = (OpenAiModel) model;
         var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
             openAiModel.getConfigurations().getChunkingSettings()

@@ -33,7 +33,7 @@ public class TextEmbeddingModelValidator implements ModelValidator {
     }
 
     private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
-        if (results instanceof TextEmbeddingResults<?, ?> embeddingResults) {
+        if (results instanceof TextEmbeddingResults<?> embeddingResults) {
             var serviceSettings = model.getServiceSettings();
             var dimensions = serviceSettings.dimensions();
             int embeddingSize = getEmbeddingSize(embeddingResults);

@@ -67,7 +67,7 @@ public class TextEmbeddingModelValidator implements ModelValidator {
         }
     }
 
-    private int getEmbeddingSize(TextEmbeddingResults<?, ?> embeddingResults) {
+    private int getEmbeddingSize(TextEmbeddingResults<?> embeddingResults) {
         int embeddingSize;
         try {
             embeddingSize = embeddingResults.getFirstEmbeddingSize();

@@ -288,7 +288,7 @@ public class VoyageAIService extends SenderService {
         VoyageAIModel voyageaiModel = (VoyageAIModel) model;
         var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
 
-        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
             inputs.getInputs(),
             getBatchSize(voyageaiModel),
             voyageaiModel.getConfigurations().getChunkingSettings()
InferenceActionResponseTests.java

@@ -11,12 +11,12 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.LegacyMlTextEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
-import org.elasticsearch.xpack.inference.results.LegacyMlTextEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
 
 import java.io.IOException;
 import java.util.ArrayList;

@@ -43,7 +43,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase
     @Override
     protected InferenceAction.Response createTestInstance() {
         var result = switch (randomIntBetween(0, 2)) {
-            case 0 -> TextEmbeddingResultsTests.createRandomResults();
+            case 0 -> TextEmbeddingFloatResultsTests.createRandomResults();
             case 1 -> LegacyMlTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults();
             default -> SparseEmbeddingResultsTests.createRandomResults();
         };

@@ -87,7 +87,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase
     }
 
     public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException {
-        var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
+        var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults();
         var instance = new InferenceAction.Response(embeddingResults);
         var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
         assertOnBWCObject(copy, instance, V_8_12_0);

@@ -103,7 +103,7 @@ public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase
     // Technically we should never see a text embedding result in the transport version of this test because support
     // for it wasn't added until openai
     public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException {
-        var embeddingResults = TextEmbeddingResultsTests.createRandomResults();
+        var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults();
         var instance = new InferenceAction.Response(embeddingResults);
         var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
         assertOnBWCObject(copy, instance, V_8_12_0);
@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference;
|
|||
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.endsWith;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
@ -33,19 +34,19 @@ import static org.hamcrest.Matchers.startsWith;
|
|||
public class EmbeddingRequestChunkerTests extends ESTestCase {
|
||||
|
||||
public void testEmptyInput_WordChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener());
|
||||
var batches = new EmbeddingRequestChunker<>(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener());
|
||||
assertThat(batches, empty());
|
||||
}
|
||||
|
||||
public void testEmptyInput_SentenceChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of(), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners(
|
||||
var batches = new EmbeddingRequestChunker<>(List.of(), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners(
|
||||
testListener()
|
||||
);
|
||||
assertThat(batches, empty());
|
||||
}
|
||||
|
||||
public void testWhitespaceInput_SentenceChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1))
|
||||
var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1))
|
||||
.batchRequestsWithListeners(testListener());
|
||||
assertThat(batches, hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
||||
|
@ -53,30 +54,29 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testBlankInput_WordChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener());
|
||||
var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener());
|
||||
assertThat(batches, hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
|
||||
}
|
||||
|
||||
public void testBlankInput_SentenceChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)).batchRequestsWithListeners(
|
||||
testListener()
|
||||
);
|
||||
var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1))
|
||||
.batchRequestsWithListeners(testListener());
|
||||
assertThat(batches, hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
|
||||
}
|
||||
|
||||
public void testInputThatDoesNotChunk_WordChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener());
|
||||
var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener());
|
||||
assertThat(batches, hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
|
||||
}
|
||||
|
||||
public void testInputThatDoesNotChunk_SentenceChunker() {
|
||||
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1))
|
||||
var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1))
|
||||
.batchRequestsWithListeners(testListener());
|
||||
assertThat(batches, hasSize(1));
|
||||
assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
||||
|
@@ -85,14 +85,14 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {

public void testShortInputsAreSingleBatch() {
String input = "one chunk";
var batches = new EmbeddingRequestChunker(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), contains(input));
}

public void testMultipleShortInputsAreSingleBatch() {
List<String> inputs = List.of("1st small", "2nd small", "3rd small");
var batches = new EmbeddingRequestChunker(inputs, 100, 100, 10).batchRequestsWithListeners(testListener());
var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch();
assertEquals(batch.inputs(), inputs);

@@ -113,7 +113,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
inputs.add("input " + i);
}

var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener());
var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(4));
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));

@@ -148,7 +148,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
inputs.add("input " + i);
}

var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings())
var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings())
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(4));
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));

@@ -190,7 +190,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {

List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");

var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener());
var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener());

assertThat(batches, hasSize(2));
@@ -234,6 +234,260 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(request.chunkText(), equalTo("3rd small"));
}

public void testVeryLongInput_Sparse() {
int batchSize = 5;
int chunkSize = 20;
int numberOfWordsInPassage = (chunkSize * 10000);

var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
}

List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");

var finalListener = testListener();
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
.batchRequestsWithListeners(finalListener);

// The very long passage is split into 10000 chunks for inference, so
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));

// Produce inference results for each request, with just the token
// "word" and increasing weights.
float weight = 0f;
for (var batch : batches) {
var embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1 / 16384f;
embeddings.add(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("word", weight)), false));
}
batch.listener().onResponse(new SparseEmbeddingResults(embeddings));
}

assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(3));

// The first input has the token with weight 1/16384f.
ChunkedInference inference = finalListener.results.get(0);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f)));

// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
// inference. They get the embeddings with token "word" and weights 2/16384 ... 10001/16384.
// Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
// and therefore 380 or 400 words. For each merged chunk, the max token weights are collected.
inference = finalListener.results.get(1);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(512));

// The first merged chunk consists of 20 small chunks (so 400 words) and the max
// weight is the weight of the 20th small chunk (so 21/16384).
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f)));

// The last merged chunk consists of 19 small chunks (so 380 words) and the max
// weight is the weight of the 10000th small chunk (so 10001/16384).
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f)));

// The last input has the token with weight 10002/16384.
inference = finalListener.results.get(2);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f)));
}
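
// A minimal sketch of the arithmetic and merge behavior exercised above; `batchCount`
// and `mergeByMaxWeight` are hypothetical helpers, not the EmbeddingRequestChunker API.
// 10002 requests at 5 inputs per batch give ceil(10002 / 5) = 2001 batches, and merging
// sparse chunks keeps the maximum weight seen for each token.
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class SparseMergeSketch {
    // Number of batches needed for `requests` inference requests, `batchSize` per batch.
    static int batchCount(int requests, int batchSize) {
        return (requests + batchSize - 1) / batchSize;
    }

    // Per-token maximum over a group of sparse chunk embeddings.
    static Map<String, Float> mergeByMaxWeight(List<Map<String, Float>> chunks) {
        Map<String, Float> merged = new HashMap<>();
        for (Map<String, Float> tokens : chunks) {
            tokens.forEach((token, weight) -> merged.merge(token, weight, Float::max));
        }
        return merged;
    }

    public static void main(String[] args) {
        System.out.println(batchCount(10002, 5)); // 2001, matching the assertion above
        // Chunks with weights 2/16384 and 21/16384 merge to the maximum, 21/16384.
        System.out.println(mergeByMaxWeight(List.of(Map.of("word", 2 / 16384f), Map.of("word", 21 / 16384f))));
    }
}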

public void testVeryLongInput_Float() {
int batchSize = 5;
int chunkSize = 20;
int numberOfWordsInPassage = (chunkSize * 10000);

var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
}

List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");

var finalListener = testListener();
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
.batchRequestsWithListeners(finalListener);

// The very long passage is split into 10000 chunks for inference, so
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));

// Produce inference results for each request, with increasing weights.
float weight = 0f;
for (var batch : batches) {
var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1 / 16384f;
embeddings.add(new TextEmbeddingFloatResults.Embedding(new float[] { weight }));
}
batch.listener().onResponse(new TextEmbeddingFloatResults(embeddings));
}

assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(3));

// The first input has the embedding with weight 1/16384.
ChunkedInference inference = finalListener.results.get(0);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f }));

// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
// inference. They get the embeddings with weights 2/16384 ... 10001/16384.
// Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
// and therefore 380 or 400 words. For each merged chunk, the average weight is computed.
inference = finalListener.results.get(1);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(512));

// The first merged chunk consists of 20 small chunks (so 400 words) and the weight
// is the average of the weights 2/16384 ... 21/16384.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) }));

// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
// is the average of the weights 9983/16384 ... 10001/16384.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) }));

// The last input has the embedding with weight 10002/16384.
inference = finalListener.results.get(2);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f }));
}
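
// A minimal sketch of the element-wise averaging that the float assertions above rely on;
// `averageEmbeddings` is a hypothetical helper, not the EmbeddingRequestChunker API.
// Averaging the weights 2/16384 ... 21/16384 gives (2 + 21) / (2 * 16384) per dimension.
import java.util.List;

class FloatAverageSketch {
    // Element-wise mean of equally sized float embeddings.
    static float[] averageEmbeddings(List<float[]> chunks) {
        float[] avg = new float[chunks.get(0).length];
        for (float[] chunk : chunks) {
            for (int i = 0; i < avg.length; i++) {
                avg[i] += chunk[i];
            }
        }
        for (int i = 0; i < avg.length; i++) {
            avg[i] /= chunks.size();
        }
        return avg;
    }

    public static void main(String[] args) {
        // Two one-dimensional chunks, weights 2/16384 and 21/16384, average to 11.5/16384.
        float[] merged = averageEmbeddings(List.of(new float[] { 2 / 16384f }, new float[] { 21 / 16384f }));
        System.out.println(merged[0] == (2 + 21) / (2 * 16384f)); // true
    }
}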

public void testVeryLongInput_Byte() {
int batchSize = 5;
int chunkSize = 20;
int numberOfWordsInPassage = (chunkSize * 10000);

var passageBuilder = new StringBuilder();
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("word").append(i).append(" "); // chunk on whitespace
}

List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small");

var finalListener = testListener();
List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0)
.batchRequestsWithListeners(finalListener);

// The very long passage is split into 10000 chunks for inference, so
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));

// Produce inference results for each request, with increasing weights.
byte weight = 0;
for (var batch : batches) {
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
for (int i = 0; i < batch.batch().requests().size(); i++) {
weight += 1;
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { weight }));
}
batch.listener().onResponse(new TextEmbeddingByteResults(embeddings));
}

assertNotNull(finalListener.results);
assertThat(finalListener.results, hasSize(3));

// The first input has the embedding with weight 1.
ChunkedInference inference = finalListener.results.get(0);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 1 }));

// The very long passage "word0 word1 ... word199999" is split into 10000 chunks for
// inference. They get the embeddings with weights 2 ... 10001, wrapping around as the
// byte counter overflows.
// Next, they are merged into 512 larger chunks, each consisting of 19 or 20 smaller chunks
// and therefore 380 or 400 words. For each merged chunk, the average weight is computed.
inference = finalListener.results.get(1);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(512));

// The first merged chunk consists of 20 small chunks (so 400 words) and the weight
// is the average of the weights 2 ... 21, so 11.5, which is rounded to 12.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 12 }));

// The last merged chunk consists of 19 small chunks (so 380 words) and the weight
// is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), i.e.
// the average of -1, 0, 1, ..., 17, which is 8.
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 "));
assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999"));
assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 8 }));

// The last input has the embedding with weight 10002 % 256 = 18.
inference = finalListener.results.get(2);
assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class));
chunkedEmbedding = (ChunkedInferenceEmbedding) inference;
assertThat(chunkedEmbedding.chunks(), hasSize(1));
assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small"));
assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding();
assertThat(embedding.values(), equalTo(new byte[] { 18 }));
}
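
// A minimal sketch of the byte averaging with overflow that the assertions above rely on;
// `averageBytes` is a hypothetical helper, not the EmbeddingRequestChunker API. The sum is
// accumulated in an int and rounded once, so 2 ... 21 average to 11.5 -> 12, and the wrapped
// values -1 ... 17 (from 9983 ... 10001) average to exactly 8.
class ByteAverageSketch {
    // Mean of signed byte values, rounded to the nearest byte.
    static byte averageBytes(byte[] values) {
        int sum = 0;
        for (byte v : values) {
            sum += v; // the inputs already wrapped modulo 256 when they were produced
        }
        return (byte) Math.round((double) sum / values.length);
    }

    public static void main(String[] args) {
        byte[] first = new byte[20];
        for (int i = 0; i < 20; i++) {
            first[i] = (byte) (i + 2); // weights 2 ... 21
        }
        System.out.println(averageBytes(first)); // 12
        byte[] last = new byte[19];
        for (int i = 0; i < 19; i++) {
            last[i] = (byte) (9983 + i); // wraps to -1 ... 17
        }
        System.out.println(averageBytes(last)); // 8
    }
}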

public void testMergingListener_Float() {
int batchSize = 5;
int chunkSize = 20;

@@ -246,10 +500,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");

var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2));

// 4 inputs in 2 batches

@@ -275,7 +529,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedFloatResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(0), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small"));
}
{
// this is the large input split into multiple chunks

@@ -283,26 +537,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(6));
assertThat(chunkedFloatResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
assertThat(chunkedFloatResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
assertThat(chunkedFloatResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
assertThat(chunkedFloatResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
assertThat(chunkedFloatResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
assertThat(chunkedFloatResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 "));
assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 "));
assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 "));
assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 "));
assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 "));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedFloatResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(2), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small"));
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedFloatResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedFloatResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(3), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small"));
}
}
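
// A minimal sketch of how the TextOffset pairs asserted above tile the original input:
// each merged chunk ends exactly where the next one starts, so substring() recovers the
// chunk text. `matchedText` mirrors the getMatchedText helper further down; the offsets
// here are illustrative values, not the ones from the test.
class OffsetSketch {
    record TextOffset(int start, int end) {}

    static String matchedText(String text, TextOffset offset) {
        return text.substring(offset.start(), offset.end());
    }

    public static void main(String[] args) {
        String input = "passage_input0 passage_input1 passage_input2";
        TextOffset first = new TextOffset(0, 30);
        TextOffset second = new TextOffset(30, input.length()); // starts where the first ends
        System.out.println(matchedText(input, first));  // "passage_input0 passage_input1 "
        System.out.println(matchedText(input, second)); // "passage_input2"
    }
}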

@@ -318,10 +572,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");

var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2));

// 4 inputs in 2 batches

@@ -347,7 +601,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small"));
}
{
// this is the large input split into multiple chunks

@@ -355,26 +609,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 "));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small"));
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small"));
}
}
@@ -390,10 +644,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("a", passageBuilder.toString(), "bb", "ccc");
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");

var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(2));

// 4 inputs in 2 batches

@@ -419,7 +673,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedByteResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small"));
}
{
// this is the large input split into multiple chunks

@@ -427,26 +681,26 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(6));
assertThat(chunkedByteResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 309)));
assertThat(chunkedByteResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(309, 629)));
assertThat(chunkedByteResult.chunks().get(2).offset(), equalTo(new ChunkedInference.TextOffset(629, 949)));
assertThat(chunkedByteResult.chunks().get(3).offset(), equalTo(new ChunkedInference.TextOffset(949, 1269)));
assertThat(chunkedByteResult.chunks().get(4).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1589)));
assertThat(chunkedByteResult.chunks().get(5).offset(), equalTo(new ChunkedInference.TextOffset(1589, 1675)));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 "));
assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 "));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedByteResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small"));
}
{
var chunkedResult = finalListener.results.get(3);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedByteResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedByteResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small"));
}
}
@@ -462,10 +716,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
for (int i = 0; i < numberOfWordsInPassage; i++) {
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
}
List<String> inputs = List.of("a", "bb", "ccc", passageBuilder.toString());
List<String> inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString());

var finalListener = testListener();
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
assertThat(batches, hasSize(3));

// 4 inputs in 3 batches

@@ -498,21 +752,21 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 1), chunkedSparseResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(0), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small"));
}
{
var chunkedResult = finalListener.results.get(1);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 2), chunkedSparseResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(1), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small"));
}
{
var chunkedResult = finalListener.results.get(2);
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(1));
assertEquals(new ChunkedInference.TextOffset(0, 3), chunkedSparseResult.chunks().get(0).offset());
assertThat(getMatchedText(inputs.get(2), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small"));
}
{
// this is the large input split into multiple chunks

@@ -520,9 +774,9 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult;
assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each
assertThat(chunkedSparseResult.chunks().get(0).offset(), equalTo(new ChunkedInference.TextOffset(0, 149)));
assertThat(chunkedSparseResult.chunks().get(1).offset(), equalTo(new ChunkedInference.TextOffset(149, 309)));
assertThat(chunkedSparseResult.chunks().get(8).offset(), equalTo(new ChunkedInference.TextOffset(1269, 1350)));
assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 "));
assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 "));
assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 "));
}
}

@@ -545,7 +799,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
}
};

var batches = new EmbeddingRequestChunker(inputs, 10, 100, 0).batchRequestsWithListeners(listener);
var batches = new EmbeddingRequestChunker<>(inputs, 10, 100, 0).batchRequestsWithListeners(listener);
assertThat(batches, hasSize(1));

var embeddings = new ArrayList<TextEmbeddingFloatResults.Embedding>();

@@ -559,6 +813,10 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
return new ChunkedResultsListener();
}

private static String getMatchedText(String text, ChunkedInference.TextOffset offset) {
return text.substring(offset.start(), offset.end());
}

private static class ChunkedResultsListener implements ActionListener<List<ChunkedInference>> {
List<ChunkedInference> results;
@@ -31,9 +31,9 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.is;

public class AmazonBedrockActionCreatorTests extends ESTestCase {
@@ -34,13 +34,13 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -42,13 +42,13 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -40,6 +40,8 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;

@@ -47,8 +49,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@@ -40,13 +40,13 @@ import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
@@ -42,13 +42,13 @@ import java.net.URISyntaxException;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -39,13 +39,13 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -41,12 +41,12 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -45,13 +45,13 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -20,11 +20,11 @@ import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.junit.After;
@@ -38,12 +38,12 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.aMapWithSize;
@@ -19,13 +19,13 @@ import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;

@@ -212,7 +212,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {

var result = listener.actionGet(TIMEOUT);

assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));

assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());

@@ -325,7 +325,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {

var result = listener.actionGet(TIMEOUT);

assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));

assertThat(webServer.requests(), hasSize(2));
{

@@ -383,7 +383,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {

var result = listener.actionGet(TIMEOUT);

assertThat(result.asMap(), is(TextEmbeddingResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));
assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F }))));

assertThat(webServer.requests(), hasSize(1));
@@ -43,12 +43,12 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.aMapWithSize;
@@ -34,6 +34,8 @@ import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;

@@ -41,8 +43,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap;
@@ -43,6 +43,7 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

@@ -52,7 +53,6 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
import static org.hamcrest.Matchers.containsString;
@@ -39,6 +39,7 @@ import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

@@ -46,7 +47,6 @@ import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiAct
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -36,12 +36,12 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -44,14 +44,14 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationBinary;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationByte;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@@ -33,9 +33,9 @@ import java.nio.charset.CharacterCodingException;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;
 

@@ -31,11 +31,11 @@ import java.util.List;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
 

@@ -44,13 +44,13 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;

@@ -13,16 +13,16 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentEOFException;
 import org.elasticsearch.xcontent.XContentParseException;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.request.Request;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.Map;
 
-import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
+import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;

@@ -179,13 +179,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());
         assert elementType == DenseVectorFieldMapper.ElementType.BYTE || elementType == DenseVectorFieldMapper.ElementType.BIT;
 
-        List<TextEmbeddingByteResults.Chunk> chunks = new ArrayList<>();
+        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
         for (String input : inputs) {
             byte[] values = new byte[embeddingLength];
             for (int j = 0; j < values.length; j++) {
                 values[j] = randomByte();
             }
-            chunks.add(new TextEmbeddingByteResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(
+                new EmbeddingResults.Chunk(
+                    new TextEmbeddingByteResults.Embedding(values),
+                    new ChunkedInference.TextOffset(0, input.length())
+                )
+            );
         }
         return new ChunkedInferenceEmbedding(chunks);
     }
@@ -195,13 +200,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());
         assert elementType == DenseVectorFieldMapper.ElementType.FLOAT;
 
-        List<TextEmbeddingFloatResults.Chunk> chunks = new ArrayList<>();
+        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
         for (String input : inputs) {
             float[] values = new float[embeddingLength];
             for (int j = 0; j < values.length; j++) {
                 values[j] = randomFloat();
             }
-            chunks.add(new TextEmbeddingFloatResults.Chunk(values, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(
+                new EmbeddingResults.Chunk(
+                    new TextEmbeddingFloatResults.Embedding(values),
+                    new ChunkedInference.TextOffset(0, input.length())
+                )
+            );
         }
         return new ChunkedInferenceEmbedding(chunks);
     }
@@ -211,13 +221,18 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
     }
 
     public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingSparse(List<String> inputs, boolean withFloats) {
-        List<SparseEmbeddingResults.Chunk> chunks = new ArrayList<>();
+        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
         for (String input : inputs) {
             var tokens = new ArrayList<WeightedToken>();
             for (var token : input.split("\\s+")) {
                 tokens.add(new WeightedToken(token, withFloats ? randomFloat() : randomIntBetween(1, 255)));
             }
-            chunks.add(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(0, input.length())));
+            chunks.add(
+                new EmbeddingResults.Chunk(
+                    new SparseEmbeddingResults.Embedding(tokens, false),
+                    new ChunkedInference.TextOffset(0, input.length())
+                )
+            );
         }
         return new ChunkedInferenceEmbedding(chunks);
     }

@@ -309,7 +324,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
     ) {
         switch (field.inference().modelSettings().taskType()) {
             case SPARSE_EMBEDDING -> {
-                List<SparseEmbeddingResults.Chunk> chunks = new ArrayList<>();
+                List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
                 for (var entry : field.inference().chunks().entrySet()) {
                     String entryField = entry.getKey();
                     List<SemanticTextField.Chunk> entryChunks = entry.getValue();
@@ -320,7 +335,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
                         String matchedText = matchedTextIt.next();
                         ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, chunk, matchedText);
                         var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType());
-                        chunks.add(new SparseEmbeddingResults.Chunk(tokens, offset));
+                        chunks.add(new EmbeddingResults.Chunk(new SparseEmbeddingResults.Embedding(tokens, false), offset));
                     }
                 }
                 return new ChunkedInferenceEmbedding(chunks);
@@ -343,11 +358,11 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
                         String matchedText = matchedTextIt.next();
                         ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText);
                         double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType());
-                        EmbeddingResults.Chunk chunk = switch (elementType) {
-                            case FLOAT -> new TextEmbeddingFloatResults.Chunk(FloatConversionUtils.floatArrayOf(values), offset);
-                            case BYTE, BIT -> new TextEmbeddingByteResults.Chunk(byteArrayOf(values), offset);
+                        EmbeddingResults.Embedding<?> embedding = switch (elementType) {
+                            case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
+                            case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values));
                         };
-                        chunks.add(chunk);
+                        chunks.add(new EmbeddingResults.Chunk(embedding, offset));
                     }
                 }
                 return new ChunkedInferenceEmbedding(chunks);

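Taken together, the SemanticTextFieldTests hunks above capture the shape of the refactor: every result-specific Chunk record (TextEmbeddingFloatResults.Chunk, TextEmbeddingByteResults.Chunk, SparseEmbeddingResults.Chunk) is replaced by the single EmbeddingResults.Chunk, which simply pairs an embedding with its text offset. A minimal sketch of the new construction path, assembled only from types that appear in these hunks (the helper name and arguments are illustrative, not part of the change set):

    // Illustrative sketch: one float chunk per input, using the unified chunk type.
    static ChunkedInferenceEmbedding floatChunks(List<String> inputs, float[] values) {
        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
        for (String input : inputs) {
            chunks.add(
                new EmbeddingResults.Chunk(
                    new TextEmbeddingFloatResults.Embedding(values),    // the embedding payload
                    new ChunkedInference.TextOffset(0, input.length())  // span of input it covers
                )
            );
        }
        return new ChunkedInferenceEmbedding(chunks);
    }

Sparse embeddings follow the same shape, wrapping a SparseEmbeddingResults.Embedding(tokens, false) instead of a float embedding.
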
@@ -20,10 +20,10 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -953,7 +953,7 @@ public class ServiceUtilsTests extends ESTestCase {
         var model = mock(Model.class);
         when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
 
-        var textEmbedding = TextEmbeddingResultsTests.createRandomResults();
+        var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults();
 
         doAnswer(invocation -> {
             ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);

@@ -31,6 +31,7 @@ import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -41,7 +42,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderT
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;

@@ -65,13 +65,13 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
-import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettingsTests.getAmazonBedrockSecretSettingsMap;
@@ -1458,10 +1458,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123F, 0.678F },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1470,10 +1470,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.223F, 0.278F },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }

@@ -1205,10 +1205,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.0123f, -0.0123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1217,10 +1217,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 1.0123f, -1.0123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }

@@ -63,6 +63,7 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -72,7 +73,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap;
@@ -1355,10 +1355,10 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123f, -0.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1367,10 +1367,10 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 1.123f, -1.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }

@@ -67,6 +67,7 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -75,7 +76,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
@@ -1468,7 +1468,7 @@ public class CohereServiceTests extends ESTestCase {
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
             assertArrayEquals(
                 new float[] { 0.123f, -0.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1479,7 +1479,7 @@ public class CohereServiceTests extends ESTestCase {
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
             assertArrayEquals(
                 new float[] { 0.223f, -0.223f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1565,16 +1565,22 @@ public class CohereServiceTests extends ESTestCase {
             var byteResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(byteResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), byteResult.chunks().get(0).offset());
-            assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
-            assertArrayEquals(new byte[] { 23, -23 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
+            assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+            assertArrayEquals(
+                new byte[] { 23, -23 },
+                ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+            );
         }
         {
             assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
             var byteResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(byteResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), byteResult.chunks().get(0).offset());
-            assertThat(byteResult.chunks().get(0), instanceOf(TextEmbeddingByteResults.Chunk.class));
-            assertArrayEquals(new byte[] { 24, -24 }, ((TextEmbeddingByteResults.Chunk) byteResult.chunks().get(0)).embedding());
+            assertThat(byteResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class));
+            assertArrayEquals(
+                new byte[] { 24, -24 },
+                ((TextEmbeddingByteResults.Embedding) byteResult.chunks().get(0).embedding()).values()
+            );
         }
 
         MatcherAssert.assertThat(webServer.requests(), hasSize(1));

@@ -37,7 +37,9 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.InferencePlugin;
@@ -48,7 +50,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
@@ -645,8 +646,11 @@ public class ElasticInferenceServiceTests extends ESTestCase {
             sparseResult.chunks(),
             is(
                 List.of(
-                    new SparseEmbeddingResults.Chunk(
-                        List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+                    new EmbeddingResults.Chunk(
+                        new SparseEmbeddingResults.Embedding(
+                            List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+                            false
+                        ),
                         new ChunkedInference.TextOffset(0, "input text".length())
                     )
                 )
@@ -767,8 +771,11 @@ public class ElasticInferenceServiceTests extends ESTestCase {
             sparseResult.chunks(),
             is(
                 List.of(
-                    new SparseEmbeddingResults.Chunk(
-                        List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+                    new EmbeddingResults.Chunk(
+                        new SparseEmbeddingResults.Embedding(
+                            List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+                            false
+                        ),
                         new ChunkedInference.TextOffset(0, "input text".length())
                     )
                 )

@@ -896,20 +896,20 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
             var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
             assertThat(result1.chunks(), hasSize(1));
-            assertThat(result1.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(result1.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
-                ((TextEmbeddingFloatResults.Chunk) result1.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) result1.chunks().get(0).embedding()).values(),
                 0.0001f
             );
             assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
             assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
             var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
             assertThat(result2.chunks(), hasSize(1));
-            assertThat(result2.chunks().get(0), instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(result2.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
-                ((TextEmbeddingFloatResults.Chunk) result2.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) result2.chunks().get(0).embedding()).values(),
                 0.0001f
             );
             assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
@@ -972,18 +972,18 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             assertThat(chunkedResponse, hasSize(2));
             assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
             var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
-            assertThat(result1.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+            assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
-                ((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
+                ((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens()
             );
             assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
             assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
             var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
-            assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+            assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
-                ((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
+                ((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens()
            );
             assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
             gotResults.set(true);
@@ -1044,18 +1044,18 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             assertThat(chunkedResponse, hasSize(2));
             assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class));
             var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0);
-            assertThat(result1.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+            assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
-                ((SparseEmbeddingResults.Chunk) result1.chunks().get(0)).weightedTokens()
+                ((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens()
             );
             assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset());
             assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class));
             var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1);
-            assertThat(result2.chunks().get(0), instanceOf(SparseEmbeddingResults.Chunk.class));
+            assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class));
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
-                ((SparseEmbeddingResults.Chunk) result2.chunks().get(0)).weightedTokens()
+                ((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens()
             );
             assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset());
             gotResults.set(true);

@@ -64,6 +64,7 @@ import java.util.stream.Collectors;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -72,7 +73,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -882,11 +882,11 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.0123f, -0.0123f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }
@@ -897,11 +897,11 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.0456f, -0.0456f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }

@@ -25,6 +25,7 @@ import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -109,8 +110,8 @@ public class HuggingFaceElserServiceTests extends ESTestCase {
             sparseResult.chunks(),
             is(
                 List.of(
-                    new SparseEmbeddingResults.Chunk(
-                        List.of(new WeightedToken(".", 0.13315596f)),
+                    new EmbeddingResults.Chunk(
+                        new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(".", 0.13315596f)), false),
                         new ChunkedInference.TextOffset(0, "abc".length())
                     )
                 )

@@ -35,12 +35,12 @@ import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
@@ -58,13 +58,13 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -788,10 +788,10 @@ public class HuggingFaceServiceTests extends ESTestCase {
             var embeddingResult = (ChunkedInferenceEmbedding) result;
             assertThat(embeddingResult.chunks(), hasSize(1));
             assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
-            assertThat(embeddingResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(embeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { -0.0123f, 0.0123f },
-                ((TextEmbeddingFloatResults.Chunk) embeddingResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) embeddingResult.chunks().get(0).embedding()).values(),
                 0.001f
             );
             assertThat(webServer.requests(), hasSize(1));
@@ -842,10 +842,10 @@ public class HuggingFaceServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 3), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123f, -0.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }

@@ -68,6 +68,7 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -76,7 +77,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
@@ -734,11 +734,11 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.0123f, -0.0123f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }
@@ -749,11 +749,11 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.0456f, -0.0456f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }

@@ -63,6 +63,7 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -71,7 +72,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
@@ -1833,10 +1833,10 @@ public class JinaAIServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123f, -0.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1845,10 +1845,10 @@ public class JinaAIServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.223f, -0.223f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }

@@ -684,11 +684,11 @@ public class MistralServiceTests extends ESTestCase {
             assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.123f, -0.123f },
                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }
@@ -696,11 +696,11 @@ public class MistralServiceTests extends ESTestCase {
             assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.223f, -0.223f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }

@@ -72,6 +72,7 @@ import static org.elasticsearch.ExceptionsHelper.unwrapCause;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
@@ -81,7 +82,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
 import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap;
@@ -1871,11 +1871,11 @@ public class OpenAiServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.123f, -0.123f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }
@@ -1884,11 +1884,11 @@ public class OpenAiServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset());
-            assertThat(floatResult.chunks().get(0), Matchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertTrue(
                 Arrays.equals(
                     new float[] { 0.223f, -0.223f },
-                    ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().get(0)).embedding()
+                    ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values()
                 )
             );
         }

@@ -13,7 +13,7 @@ import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.junit.Before;
 import org.mockito.Mock;
 

@@ -14,12 +14,12 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.EmptyTaskSettingsTests;
 import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
-import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
-import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
 import org.junit.Before;
 import org.mockito.Mock;
 

@@ -61,6 +61,7 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@@ -69,7 +70,6 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
@@ -1840,10 +1840,10 @@ public class VoyageAIServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.getFirst();
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().getFirst().offset());
-            assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().get(0).embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.123f, -0.123f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().getFirst()).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }
@@ -1852,10 +1852,10 @@ public class VoyageAIServiceTests extends ESTestCase {
             var floatResult = (ChunkedInferenceEmbedding) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().getFirst().offset());
-            assertThat(floatResult.chunks().getFirst(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Chunk.class));
+            assertThat(floatResult.chunks().getFirst().embedding(), CoreMatchers.instanceOf(TextEmbeddingFloatResults.Embedding.class));
             assertArrayEquals(
                 new float[] { 0.223f, -0.223f },
-                ((TextEmbeddingFloatResults.Chunk) floatResult.chunks().getFirst()).embedding(),
+                ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(),
                 0.0f
             );
         }

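Across every service test above, the consumption side converges on one pattern: a chunk is no longer itself the concrete result type; it carries an embedding that is cast to the concrete Embedding record when its payload is needed. A condensed sketch of that pattern, using only calls visible in this diff (floatResult and sparseResult stand in for the test locals):

    // Illustrative only: unwrapping the unified chunk, mirroring the assertions above.
    EmbeddingResults.Chunk chunk = floatResult.chunks().get(0);
    ChunkedInference.TextOffset offset = chunk.offset();
    float[] dense = ((TextEmbeddingFloatResults.Embedding) chunk.embedding()).values();
    // Sparse results expose tokens instead of values:
    // ((SparseEmbeddingResults.Embedding) sparseResult.chunks().get(0).embedding()).tokens();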