[ML] Inference API rate limit queuing logic refactor (#107706)

* Adding new executor

* Adding in queuing logic

* working tests

* Added cleanup task

* Update docs/changelog/107706.yaml

* Updating yml

* deregistering callbacks for settings changes

* Cleaning up code

* Update docs/changelog/107706.yaml

* Fixing rate limit settings bug and only sleeping least amount

* Removing debug logging

* Removing commented code

* Renaming feedback

* fixing tests

* Updating docs and validation

* Fixing source blocks

* Adjusting cancel logic

* Reformatting ascii

* Addressing feedback

* adding rate limiting for google embeddings and mistral

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Jonathan Buttner 2024-06-05 08:25:25 -04:00 committed by GitHub
parent cd84749d87
commit fdb5058b13
GPG key ID: B5690EEEBB952194
102 changed files with 1499 additions and 937 deletions


@ -0,0 +1,5 @@
pr: 107706
summary: Add rate limiting support for the Inference API
area: Machine Learning
type: enhancement
issues: []


@ -7,21 +7,17 @@ experimental[]
Creates an {infer} endpoint to perform an {infer} task.
IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in
{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face.
For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models.
However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the
<<ml-df-trained-models-apis>>.
[discrete]
[[put-inference-api-request]]
==== {api-request-title}
`PUT /_inference/<task_type>/<inference_id>`
[discrete]
[[put-inference-api-prereqs]]
==== {api-prereq-title}
@ -29,7 +25,6 @@ use these models or if you want to use non-NLP models, use the
* Requires the `manage_inference` <<privileges-list-cluster,cluster privilege>>
(the built-in `inference_admin` role grants this privilege)
[discrete]
[[put-inference-api-desc]]
==== {api-description-title}
@ -48,25 +43,23 @@ The following services are available through the {infer} API:
* Hugging Face
* OpenAI
[discrete]
[[put-inference-api-path-params]]
==== {api-path-parms-title}
`<inference_id>`::
(Required, string)
The unique identifier of the {infer} endpoint.
`<task_type>`::
(Required, string)
The type of the {infer} task that the model will perform.
Available task types:
* `completion`,
* `rerank`,
* `sparse_embedding`,
* `text_embedding`.
[discrete]
[[put-inference-api-request-body]]
==== {api-request-body-title}
@ -78,21 +71,18 @@ Available services:
* `azureopenai`: specify the `completion` or `text_embedding` task type to use the Azure OpenAI service.
* `azureaistudio`: specify the `completion` or `text_embedding` task type to use the Azure AI Studio service.
* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service.
* `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland.
* `elser`: specify the `sparse_embedding` task type to use the ELSER service.
* `googleaistudio`: specify the `completion` task to use the Google AI Studio service.
* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service.
* `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service.
`service_settings`::
(Required, object)
Settings used to install the {infer} model.
These settings are specific to the `service` you specified.
+
.`service_settings` for the `azureaistudio` service
@ -104,11 +94,10 @@ Settings used to install the {infer} model. These settings are specific to the
A valid API key of your Azure AI Studio model deployment.
This key can be found on the overview page for your deployment in the management section of your https://ai.azure.com/[Azure AI Studio] account.
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`target`:::
(Required, string)
@ -142,11 +131,13 @@ For "real-time" endpoints which are billed per hour of usage, specify `realtime`
By default, the `azureaistudio` service sets the number of requests allowed per minute to `240`.
This helps to minimize the number of rate limit errors returned from Azure AI Studio.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
=====
+
.`service_settings` for the `azureopenai` service
@ -181,6 +172,22 @@ Your Azure OpenAI deployments can be found though the https://oai.azure.com/[Azu
The Azure API version ID to use.
We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings[latest supported non-preview version].
`rate_limit`:::
(Optional, object)
The `azureopenai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `1440`.
For `completion` it is set to `120`.
This helps to minimize the number of rate limit errors returned from Azure.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for Azure can be found in the https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits[Quota limits docs] and https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest[How to change the quotas].
=====
+
.`service_settings` for the `cohere` service
@ -188,24 +195,24 @@ We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/opena
=====
`api_key`:::
(Required, string)
A valid API key of your Cohere account.
You can find your Cohere API keys or you can create a new one
https://dashboard.cohere.com/api-keys[on the API keys settings page].
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`embedding_type`::
(Optional, string)
Only for `text_embedding`.
Specifies the types of embeddings you want to get back.
Defaults to `float`.
Valid values are:
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
* `float`: use it for the default float embeddings.
* `int8`: use it for signed int8 embeddings.
`model_id`::
(Optional, string)
@ -214,50 +221,68 @@ To review the available `rerank` models, refer to the
https://docs.cohere.com/reference/rerank-1[Cohere docs].
To review the available `text_embedding` models, refer to the
https://docs.cohere.com/reference/embed[Cohere docs].
The default value for `text_embedding` is `embed-english-v2.0`.
`rate_limit`:::
(Optional, object)
By default, the `cohere` service sets the number of requests allowed per minute to `10000`.
This value is the same for all task types.
This helps to minimize the number of rate limit errors returned from Cohere.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
+
More information about Cohere's rate limits can be found in https://docs.cohere.com/docs/going-live#production-key-specifications[Cohere's production key docs].
=====
+
.`service_settings` for the `elasticsearch` service
[%collapsible%closed]
=====
`model_id`:::
(Required, string)
The name of the model to use for the {infer} task.
It can be the ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or a text embedding model already
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].
`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.
`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.
=====
+
.`service_settings` for the `elser` service
[%collapsible%closed]
=====
`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.
`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.
=====
+
.`service_settings` for the `googleaistudio` service
[%collapsible%closed]
=====
`api_key`:::
(Required, string)
A valid API key for the Google Gemini API.
@ -274,76 +299,113 @@ This helps to minimize the number of rate limit errors returned from Google AI S
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
--
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
--
=====
+
.`service_settings` for the `hugging_face` service
[%collapsible%closed]
=====
`api_key`:::
(Required, string)
A valid access token of your Hugging Face account.
You can find your Hugging Face access tokens or you can create a new one
https://huggingface.co/settings/tokens[on the settings page].
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`url`:::
(Required, string)
The URL endpoint to use for the requests.
`rate_limit`:::
(Optional, object)
By default, the `huggingface` service sets the number of requests allowed per minute to `3000`.
This helps to minimize the number of rate limit errors returned from Hugging Face.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
=====
+
.`service_settings` for the `openai` service
[%collapsible%closed]
=====
`api_key`:::
(Required, string)
A valid API key of your OpenAI account.
You can find your OpenAI API keys in your OpenAI account under the
https://platform.openai.com/api-keys[API keys section].
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`model_id`:::
(Required, string)
The name of the model to use for the {infer} task.
Refer to the
https://platform.openai.com/docs/guides/embeddings/what-are-embeddings[OpenAI documentation]
for the list of available text embedding models.
`organization_id`:::
(Optional, string)
The unique identifier of your organization.
You can find the Organization ID in your OpenAI account under
https://platform.openai.com/account/organization[**Settings** > **Organizations**].
`url`:::
(Optional, string)
The URL endpoint to use for the requests.
Can be changed for testing purposes.
Defaults to `https://api.openai.com/v1/embeddings`.
`rate_limit`:::
(Optional, object)
The `openai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `3000`.
For `completion` it is set to `500`.
This helps to minimize the number of rate limit errors returned from OpenAI.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for OpenAI can be found in your https://platform.openai.com/account/limits[Account limits].
=====
`task_settings`::
(Optional, object)
Settings to configure the {infer} task.
These settings are specific to the `<task_type>` you specified.
+
.`task_settings` for the `completion` task type
[%collapsible%closed]
=====
`do_sample`:::
(Optional, float)
For the `azureaistudio` service only.
@ -358,8 +420,8 @@ Defaults to 64.
`user`:::
(Optional, string)
For `openai` service only.
Specifies the user issuing the request, which can be used for abuse detection.
`temperature`:::
(Optional, float)
@ -378,45 +440,46 @@ Should not be used if `temperature` is specified.
.`task_settings` for the `rerank` task type
[%collapsible%closed]
=====
`return_documents`::
(Optional, boolean)
For `cohere` service only.
Specify whether to return doc text within the results.
`top_n`::
(Optional, integer)
The number of most relevant documents to return. Defaults to the number of documents.
=====
+
.`task_settings` for the `text_embedding` task type
[%collapsible%closed]
=====
`input_type`:::
(Optional, string)
For `cohere` service only.
Specifies the type of input passed to the model.
Valid values are:
* `classification`: use it for embeddings passed through a text classifier.
* `clustering`: use it for the embeddings run through a clustering algorithm.
* `ingest`: use it for storing document embeddings in a vector database.
* `search`: use it for storing embeddings of search queries run against a vector database to find relevant documents.
`truncate`:::
(Optional, string)
For `cohere` service only.
Specifies how the API handles inputs longer than the maximum token length.
Defaults to `END`.
Valid values are:
* `NONE`: when the input exceeds the maximum input token length an error is returned.
* `START`: when the input exceeds the maximum input token length the start of the input is discarded.
* `END`: when the input exceeds the maximum input token length the end of the input is discarded.
`user`:::
(optional, string)
For `openai`, `azureopenai` and `azureaistudio` services only.
Specifies the user issuing the request, which can be used for abuse detection.
=====
[discrete]
@ -470,7 +533,6 @@ PUT _inference/completion/azure_ai_studio_completion
The list of chat completion models that you can choose from in your deployment can be found in the https://ai.azure.com/explore/models?selectedTask=chat-completion[Azure AI Studio model explorer].
[discrete]
[[inference-example-azureopenai]]
===== Azure OpenAI service
@ -519,7 +581,6 @@ The list of chat completion models that you can choose from in your Azure OpenAI
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-models[GPT-4 and GPT-4 Turbo models]
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35[GPT-3.5]
[discrete]
[[inference-example-cohere]]
===== Cohere service
@ -565,7 +626,6 @@ PUT _inference/rerank/cohere-rerank
For more examples, also review the
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].
[discrete]
[[inference-example-e5]]
===== E5 via the `elasticsearch` service
@ -586,10 +646,9 @@ PUT _inference/text_embedding/my-e5-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of one of the built-in E5 models.
Valid values are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`.
For further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].
[discrete]
[[inference-example-elser]]
@ -597,8 +656,7 @@ further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].
The following example shows how to create an {infer} endpoint called
`my-elser-model` to perform a `sparse_embedding` task type.
Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more info.
[source,console]
------------------------------------------------------------
@ -672,16 +730,17 @@ PUT _inference/text_embedding/hugging-face-embeddings
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> A valid Hugging Face access token.
You can find it on the
https://huggingface.co/settings/tokens[settings page of your account].
<2> The {infer} endpoint URL you created on Hugging Face.
Create a new {infer} endpoint on
https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an endpoint URL.
Select the model you want to use on the new endpoint creation page - for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings` task under the Advanced configuration section.
Create the endpoint.
Copy the URL after the endpoint initialization has been finished.
[discrete]
[[inference-example-hugging-face-supported-models]]
@ -695,7 +754,6 @@ The list of recommended models for the Hugging Face service:
* https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base]
* https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small]
[discrete]
[[inference-example-eland]]
===== Models uploaded by Eland via the elasticsearch service
@ -716,11 +774,9 @@ PUT _inference/text_embedding/my-msmarco-minilm-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of a text embedding model which has already been
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].
[discrete]
[[inference-example-openai]]
===== OpenAI service
@ -756,4 +812,3 @@ PUT _inference/completion/openai-completion
}
------------------------------------------------------------
// TEST[skip:TBD]
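The `rate_limit` object documented above can be added to the `service_settings` of any supported service. A minimal sketch (the endpoint name `openai-completion-rate-limited` and the limit value are illustrative, not from the PR) that lowers the OpenAI completion default of `500` requests per minute:

[source,console]
------------------------------------------------------------
PUT _inference/completion/openai-completion-rate-limited
{
    "service": "openai",
    "service_settings": {
        "api_key": "<api_key>",
        "model_id": "gpt-3.5-turbo",
        "rate_limit": {
            "requests_per_minute": 100
        }
    }
}
------------------------------------------------------------
// TEST[skip:TBD]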


@ -88,6 +88,13 @@ public class TimeValue implements Comparable<TimeValue> {
return new TimeValue(days, TimeUnit.DAYS);
}
/**
* @return the {@link TimeValue} object that has the least duration.
*/
public static TimeValue min(TimeValue time1, TimeValue time2) {
return time1.compareTo(time2) < 0 ? time1 : time2;
}
/**
* @return the unit used for this time value, see {@link #duration()}
*/
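The new `min` helper picks the shorter of two waits via `compareTo`, which is what lets the rate limiter sleep only the least required amount. The same pattern with `java.time.Duration` (a standalone sketch, not the Elasticsearch `TimeValue` class itself):

```java
import java.time.Duration;

public class MinDuration {
    // Mirrors TimeValue.min: returns whichever argument has the least duration.
    static Duration min(Duration a, Duration b) {
        return a.compareTo(b) < 0 ? a : b;
    }

    public static void main(String[] args) {
        // A queue that must honor both the rate limiter's next-token wait and a
        // shorter polling interval sleeps only the least of the two.
        Duration untilNextToken = Duration.ofMillis(250);
        Duration pollInterval = Duration.ofSeconds(1);
        System.out.println(min(untilNextToken, pollInterval)); // PT0.25S
    }
}
```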


@ -17,6 +17,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.object.HasToString.hasToString;
@ -231,6 +232,12 @@ public class TimeValueTests extends ESTestCase {
assertThat(ex.getMessage(), containsString("duration cannot be negative"));
}
public void testMin() {
assertThat(TimeValue.min(TimeValue.ZERO, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(0)));
assertThat(TimeValue.min(TimeValue.MAX_VALUE, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(1)));
assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE));
}
private TimeUnit randomTimeUnitObject() {
return randomFrom(
TimeUnit.NANOSECONDS,


@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor {
private final ServiceComponents serviceComponents;
public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
// TODO Batching - accept a class that can handle batching
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}


@ -36,6 +36,7 @@ public class CohereEmbeddingsAction implements ExecutableAction {
model.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings"
);
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool);
}


@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -37,17 +36,16 @@ public class AzureAiStudioChatCompletionRequestManager extends AzureAiStudioRequ
}
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
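The shape of this refactor — managers no longer hand back a `Runnable` for the caller to schedule, but execute the prepared request themselves — can be sketched in isolation (all names here are illustrative stand-ins, not the actual Elasticsearch interfaces):

```java
import java.util.List;
import java.util.function.Supplier;

public class ExecuteRefactorSketch {
    // Old contract (sketch): Runnable create(List<String> input, Supplier<Boolean> hasCompleted);
    // the caller received a Runnable and decided when to run it.
    //
    // New contract (sketch): the manager runs the prepared request itself, so a
    // rate-limiting base class can queue or delay it before it is actually sent.

    int sent = 0;

    public void execute(List<String> input, Supplier<Boolean> hasCompleted) {
        if (hasCompleted.get()) {
            return; // the request was already answered or cancelled; skip sending
        }
        sent++; // in the real code: hand an ExecutableInferenceRequest to the queue
    }

    public static void main(String[] args) {
        var manager = new ExecuteRefactorSketch();
        manager.execute(List.of("hello"), () -> false); // sends
        manager.execute(List.of("hello"), () -> true);  // skipped: already completed
        System.out.println(manager.sent); // 1
    }
}
```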
private static ResponseHandler createCompletionHandler() {


@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -41,17 +40,16 @@ public class AzureAiStudioEmbeddingsRequestManager extends AzureAiStudioRequestM
}
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
private static ResponseHandler createEmbeddingsHandler() {


@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -43,16 +42,15 @@ public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManag
}
@Override
public void execute(
@Nullable String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}


@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -55,16 +54,15 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag
}
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}


@ -38,11 +38,16 @@ abstract class BaseRequestManager implements RequestManager {
@Override
public Object rateLimitGrouping() {
// It's possible for two inference endpoints to have the same information defining the group but different
// rate limits. In that case they should be in different groups; otherwise the endpoint that initially
// created the group would set the rate and the other endpoint's rate limit would be ignored.
return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}
private record EndpointGrouping(Object group, RateLimitSettings settings) {}
}
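The `EndpointGrouping` record works because Java records get value-based `equals`/`hashCode` for free, so two endpoints sharing a group key but carrying different rate limits hash into different groups. A standalone sketch of that behavior (the `RateLimit` record and the string keys are illustrative stand-ins):

```java
import java.util.HashMap;
import java.util.Map;

public class GroupingSketch {
    // Stand-ins for the real rate limit settings and group key.
    record RateLimit(int requestsPerMinute) {}
    record EndpointGrouping(Object group, RateLimit settings) {}

    public static void main(String[] args) {
        // Same group key, different rate limits: these must land in distinct
        // groups, otherwise the first endpoint's rate would silently win.
        var a = new EndpointGrouping("openai-embeddings", new RateLimit(3000));
        var b = new EndpointGrouping("openai-embeddings", new RateLimit(100));

        Map<EndpointGrouping, String> groups = new HashMap<>();
        groups.put(a, "limiter A");
        groups.put(b, "limiter B");
        System.out.println(groups.size()); // 2: records compare by value
    }
}
```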


@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -46,16 +45,15 @@ public class CohereCompletionRequestManager extends CohereRequestManager {
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
CohereCompletionRequest request = new CohereCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -44,16 +43,15 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager {
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -44,16 +43,15 @@ public class CohereRerankRequestManager extends CohereRequestManager {
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
CohereRerankRequest request = new CohereRerankRequest(query, input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View file

@ -23,7 +23,6 @@ record ExecutableInferenceRequest(
RequestSender requestSender,
Logger logger,
Request request,
HttpClientContext context,
ResponseHandler responseHandler,
Supplier<Boolean> hasFinished,
ActionListener<InferenceServiceResults> listener
@ -34,7 +33,7 @@ record ExecutableInferenceRequest(
var inferenceEntityId = request.createHttpRequest().inferenceEntityId();
try {
requestSender.send(logger, request, context, hasFinished, responseHandler, listener);
requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
} catch (Exception e) {
var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
logger.warn(errorMessage, e);

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -42,15 +41,14 @@ public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioReques
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -48,17 +47,16 @@ public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioReques
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View file

@ -15,6 +15,8 @@ import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@ -39,30 +41,28 @@ public class HttpRequestSender implements Sender {
private final ServiceComponents serviceComponents;
private final HttpClientManager httpClientManager;
private final ClusterService clusterService;
private final SingleRequestManager requestManager;
private final RequestSender requestSender;
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.httpClientManager = Objects.requireNonNull(httpClientManager);
this.clusterService = Objects.requireNonNull(clusterService);
var requestSender = new RetryingHttpSender(
requestSender = new RetryingHttpSender(
this.httpClientManager.getHttpClient(),
serviceComponents.throttlerManager(),
new RetrySettings(serviceComponents.settings(), clusterService),
serviceComponents.threadPool()
);
requestManager = new SingleRequestManager(requestSender);
}
public Sender createSender(String serviceName) {
public Sender createSender() {
return new HttpRequestSender(
serviceName,
serviceComponents.threadPool(),
httpClientManager,
clusterService,
serviceComponents.settings(),
requestManager
requestSender
);
}
}
@ -71,26 +71,24 @@ public class HttpRequestSender implements Sender {
private final ThreadPool threadPool;
private final HttpClientManager manager;
private final RequestExecutorService service;
private final RequestExecutor service;
private final AtomicBoolean started = new AtomicBoolean(false);
private final CountDownLatch startCompleted = new CountDownLatch(1);
private HttpRequestSender(
String serviceName,
ThreadPool threadPool,
HttpClientManager httpClientManager,
ClusterService clusterService,
Settings settings,
SingleRequestManager requestManager
RequestSender requestSender
) {
this.threadPool = Objects.requireNonNull(threadPool);
this.manager = Objects.requireNonNull(httpClientManager);
service = new RequestExecutorService(
serviceName,
threadPool,
startCompleted,
new RequestExecutorServiceSettings(settings, clusterService),
requestManager
requestSender
);
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -55,26 +54,17 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(
requestSender,
logger,
request,
context,
responseHandler,
hasRequestCompletedFunction,
listener
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}
record RateLimitGrouping(int accountHash) {

View file

@ -19,9 +19,9 @@ import java.util.function.Supplier;
public interface InferenceRequest {
/**
* Returns the creator that handles building an executable request based on the input provided.
* Returns the manager that handles building and executing an inference request.
*/
RequestManager getRequestCreator();
RequestManager getRequestManager();
/**
* Returns the query associated with this request. Used for Rerank tasks.

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -51,18 +50,17 @@ public class MistralEmbeddingsRequestManager extends BaseRequestManager {
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
record RateLimitGrouping(int keyHashCode) {

View file

@ -1,52 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import java.util.List;
import java.util.function.Supplier;
class NoopTask implements RejectableTask {
@Override
public RequestManager getRequestCreator() {
return null;
}
@Override
public String getQuery() {
return null;
}
@Override
public List<String> getInput() {
return null;
}
@Override
public ActionListener<InferenceServiceResults> getListener() {
return null;
}
@Override
public boolean hasCompleted() {
return true;
}
@Override
public Supplier<Boolean> getRequestCompletedFunction() {
return () -> true;
}
@Override
public void onRejection(Exception e) {
}
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -43,17 +42,16 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager {
}
@Override
public Runnable create(
public void execute(
@Nullable String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
private static ResponseHandler createCompletionHandler() {

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -55,17 +54,16 @@ public class OpenAiEmbeddingsRequestManager extends OpenAiRequestManager {
}
@Override
public Runnable create(
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
@ -17,21 +16,31 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue;
import org.elasticsearch.xpack.inference.common.RateLimiter;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
/**
* A service for queuing and executing {@link RequestTask}. This class is useful because the
@ -45,7 +54,18 @@ import static org.elasticsearch.core.Strings.format;
* {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info.
*/
class RequestExecutorService implements RequestExecutor {
private static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> QUEUE_CREATOR =
/**
* Provides dependency injection mainly for testing
*/
interface Sleeper {
void sleep(TimeValue sleepTime) throws InterruptedException;
}
// default for testing
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
// default for testing
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@Override
public BlockingQueue<RejectableTask> create(int capacity) {
@ -65,86 +85,116 @@ class RequestExecutorService implements RequestExecutor {
}
};
/**
* Provides dependency injection mainly for testing
*/
interface RateLimiterCreator {
RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit);
}
// default for testing
static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new;
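The `RateLimiterCreator` signature above (`accumulatedTokensLimit`, `tokensPerTimeUnit`, `TimeUnit`) suggests a token-bucket design; the real `RateLimiter` class is not shown in this diff, so the following is only a simplified sketch of the `timeToReserve`/`reserve` contract the handler code below relies on:

```java
import java.util.concurrent.TimeUnit;

// Simplified token-bucket sketch; the real x-pack RateLimiter may differ.
// Time is passed in explicitly (nowNanos) to keep the sketch testable.
class TokenBucket {
    private final double tokensPerNano;
    private final double maxTokens;
    private double tokens;
    private long lastRefillNanos;

    TokenBucket(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, long nowNanos) {
        this.maxTokens = accumulatedTokensLimit;
        this.tokensPerNano = tokensPerTimeUnit / unit.toNanos(1);
        this.tokens = accumulatedTokensLimit;
        this.lastRefillNanos = nowNanos;
    }

    private void refill(long nowNanos) {
        // accumulate tokens for the elapsed time, capped at the burst limit
        tokens = Math.min(maxTokens, tokens + (nowNanos - lastRefillNanos) * tokensPerNano);
        lastRefillNanos = nowNanos;
    }

    /** Nanoseconds until {@code n} tokens are available; 0 means execute immediately. */
    long timeToReserveNanos(int n, long nowNanos) {
        refill(nowNanos);
        double deficit = n - tokens;
        return deficit <= 0 ? 0 : (long) Math.ceil(deficit / tokensPerNano);
    }

    /** Takes {@code n} tokens; callers are expected to pre-check with timeToReserveNanos. */
    void reserve(int n, long nowNanos) {
        refill(nowNanos);
        tokens -= n;
    }
}
```

With `ACCUMULATED_TOKENS_LIMIT = 1` as in this class, at most one request can burst before the limiter begins spacing requests out at the configured rate.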
private static final Logger logger = LogManager.getLogger(RequestExecutorService.class);
private final String serviceName;
private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
private final AtomicBoolean running = new AtomicBoolean(true);
private final CountDownLatch terminationLatch = new CountDownLatch(1);
private final HttpClientContext httpContext;
private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
private final ConcurrentMap<Object, RateLimitingEndpointHandler> rateLimitGroupings = new ConcurrentHashMap<>();
private final ThreadPool threadPool;
private final CountDownLatch startupLatch;
private final BlockingQueue<Runnable> controlQueue = new LinkedBlockingQueue<>();
private final SingleRequestManager requestManager;
private final CountDownLatch terminationLatch = new CountDownLatch(1);
private final RequestSender requestSender;
private final RequestExecutorServiceSettings settings;
private final Clock clock;
private final AtomicBoolean shutdown = new AtomicBoolean(false);
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
private final Sleeper sleeper;
private final RateLimiterCreator rateLimiterCreator;
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
private final AtomicBoolean started = new AtomicBoolean(false);
RequestExecutorService(
String serviceName,
ThreadPool threadPool,
@Nullable CountDownLatch startupLatch,
RequestExecutorServiceSettings settings,
SingleRequestManager requestManager
RequestSender requestSender
) {
this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager);
}
/**
* This constructor should only be used directly for testing.
*/
RequestExecutorService(
String serviceName,
ThreadPool threadPool,
AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
@Nullable CountDownLatch startupLatch,
RequestExecutorServiceSettings settings,
SingleRequestManager requestManager
) {
this.serviceName = Objects.requireNonNull(serviceName);
this.threadPool = Objects.requireNonNull(threadPool);
this.httpContext = HttpClientContext.create();
this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
this.startupLatch = startupLatch;
this.requestManager = Objects.requireNonNull(requestManager);
Objects.requireNonNull(settings);
settings.registerQueueCapacityCallback(this::onCapacityChange);
}
private void onCapacityChange(int capacity) {
logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity));
var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity));
if (enqueuedCapacityCommand == false) {
logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later.");
} else {
// ensure that the task execution loop wakes up
queue.offer(new NoopTask());
}
}
private void updateCapacity(int newCapacity) {
try {
queue.setCapacity(newCapacity);
} catch (Exception e) {
logger.warn(
format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName),
e
this(
threadPool,
DEFAULT_QUEUE_CREATOR,
startupLatch,
settings,
requestSender,
Clock.systemUTC(),
DEFAULT_SLEEPER,
DEFAULT_RATE_LIMIT_CREATOR
);
}
RequestExecutorService(
ThreadPool threadPool,
AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator,
@Nullable CountDownLatch startupLatch,
RequestExecutorServiceSettings settings,
RequestSender requestSender,
Clock clock,
Sleeper sleeper,
RateLimiterCreator rateLimiterCreator
) {
this.threadPool = Objects.requireNonNull(threadPool);
this.queueCreator = Objects.requireNonNull(queueCreator);
this.startupLatch = startupLatch;
this.requestSender = Objects.requireNonNull(requestSender);
this.settings = Objects.requireNonNull(settings);
this.clock = Objects.requireNonNull(clock);
this.sleeper = Objects.requireNonNull(sleeper);
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
}
public void shutdown() {
if (shutdown.compareAndSet(false, true)) {
if (cancellableCleanupTask.get() != null) {
logger.debug(() -> "Stopping clean up thread");
cancellableCleanupTask.get().cancel();
}
}
}
public boolean isShutdown() {
return shutdown.get();
}
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return terminationLatch.await(timeout, unit);
}
public boolean isTerminated() {
return terminationLatch.getCount() == 0;
}
public int queueSize() {
return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum();
}
/**
* Begin servicing tasks.
* <p>
* <b>Note: This should only be called once for the life of the object.</b>
* </p>
*/
public void start() {
try {
assert started.get() == false : "start() can only be called once";
started.set(true);
startCleanupTask();
signalStartInitiated();
while (running.get()) {
while (isShutdown() == false) {
handleTasks();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
running.set(false);
shutdown();
notifyRequestsOfShutdown();
terminationLatch.countDown();
}
@ -156,108 +206,68 @@ class RequestExecutorService implements RequestExecutor {
}
}
/**
* Protects the task retrieval logic from an unexpected exception.
*
* @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to
* shut down
*/
private void startCleanupTask() {
assert cancellableCleanupTask.get() == null : "The clean up task can only be set once";
cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL));
}
private Scheduler.Cancellable startCleanupThread(TimeValue interval) {
logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval));
return threadPool.scheduleWithFixedDelay(this::removeStaleGroupings, interval, threadPool.executor(UTILITY_THREAD_POOL_NAME));
}
// default for testing
void removeStaleGroupings() {
var now = Instant.now(clock);
for (var iter = rateLimitGroupings.values().iterator(); iter.hasNext();) {
var endpoint = iter.next();
// remove the endpoint if the current time is after its last enqueue time plus the allowed stale period
if (now.isAfter(endpoint.timeOfLastEnqueue().plus(settings.getRateLimitGroupStaleDuration()))) {
endpoint.close();
iter.remove();
}
}
}
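The eviction rule in `removeStaleGroupings()` can be sketched in isolation: a grouping is removed once its last enqueue time plus the stale duration is in the past. The map and method names here are hypothetical; only the staleness comparison mirrors the code above:

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Iterator;
import java.util.Map;

// Sketch of the staleness check: evict entries whose (lastEnqueue + staleDuration)
// is before "now" according to the injected clock.
class StaleEvictionDemo {
    static void evictStale(Map<Object, Instant> lastEnqueueByGroup, Duration staleDuration, Clock clock) {
        Instant now = Instant.now(clock);
        for (Iterator<Instant> iter = lastEnqueueByGroup.values().iterator(); iter.hasNext();) {
            if (now.isAfter(iter.next().plus(staleDuration))) {
                iter.remove(); // safe removal while iterating
            }
        }
    }
}
```

Injecting a `Clock` (as the real class does) makes this behavior deterministic under test with `Clock.fixed(...)`.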
private void handleTasks() throws InterruptedException {
try {
RejectableTask task = queue.take();
var command = controlQueue.poll();
if (command != null) {
command.run();
var timeToWait = settings.getTaskPollFrequency();
for (var endpoint : rateLimitGroupings.values()) {
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
}
// TODO add logic to complete pending items in the queue before shutting down
if (running.get() == false) {
logger.debug(() -> format("Http executor service [%s] exiting", serviceName));
rejectTaskBecauseOfShutdown(task);
} else {
executeTask(task);
}
} catch (InterruptedException e) {
throw e;
} catch (Exception e) {
logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e);
}
sleeper.sleep(timeToWait);
}
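The loop in `handleTasks()` polls every rate limit grouping once, then sleeps only for the smallest of the suggested waits (capped by the configured poll frequency), so a soon-to-be-ready endpoint is not delayed by slower ones. A sketch of just that minimum-wait selection, using `Duration` in place of `TimeValue`:

```java
import java.time.Duration;
import java.util.List;

// Sketch of the sleep calculation: start from the default poll frequency and
// shrink it to the soonest wait any endpoint reported.
class MinSleepDemo {
    static Duration nextSleep(Duration pollFrequency, List<Duration> perEndpointWaits) {
        Duration timeToWait = pollFrequency;
        for (Duration wait : perEndpointWaits) {
            if (wait.compareTo(timeToWait) < 0) {
                timeToWait = wait; // an endpoint has work sooner than the default poll
            }
        }
        return timeToWait;
    }
}
```

A zero wait (an endpoint just executed a task) makes the loop spin again immediately; an empty set of endpoints falls back to the poll frequency.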
private void executeTask(RejectableTask task) {
try {
requestManager.execute(task, httpContext);
} catch (Exception e) {
logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e);
}
}
private synchronized void notifyRequestsOfShutdown() {
private void notifyRequestsOfShutdown() {
assert isShutdown() : "Requests should only be notified if the executor is shutting down";
try {
List<RejectableTask> notExecuted = new ArrayList<>();
queue.drainTo(notExecuted);
rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown);
} catch (Exception e) {
logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName));
for (var endpoint : rateLimitGroupings.values()) {
endpoint.notifyRequestsOfShutdown();
}
}
private void rejectTaskBecauseOfShutdown(RejectableTask task) {
try {
task.onRejection(
new EsRejectedExecutionException(
format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName),
true
)
);
} catch (Exception e) {
logger.warn(
format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName)
);
}
// default for testing
Integer remainingQueueCapacity(RequestManager requestManager) {
var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping());
if (endpoint == null) {
return null;
}
private void rejectTasks(List<RejectableTask> tasks, Consumer<RejectableTask> rejectionFunction) {
for (var task : tasks) {
rejectionFunction.accept(task);
}
return endpoint.remainingCapacity();
}
public int queueSize() {
return queue.size();
}
@Override
public void shutdown() {
if (running.compareAndSet(true, false)) {
// if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns
queue.offer(new NoopTask());
}
}
@Override
public boolean isShutdown() {
return running.get() == false;
}
@Override
public boolean isTerminated() {
return terminationLatch.getCount() == 0;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return terminationLatch.await(timeout, unit);
// default for testing
int numberOfRateLimitGroups() {
return rateLimitGroupings.size();
}
/**
* Execute the request at some point in the future.
*
* @param requestCreator the http request to send
* @param requestManager the http request to send
* @param inferenceInputs the inputs to send in the request
* @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the
* listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}.
@ -265,13 +275,13 @@ class RequestExecutorService implements RequestExecutor {
* @param listener an {@link ActionListener<InferenceServiceResults>} for the response or failure
*/
public void execute(
RequestManager requestCreator,
RequestManager requestManager,
InferenceInputs inferenceInputs,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
var task = new RequestTask(
requestCreator,
requestManager,
inferenceInputs,
timeout,
threadPool,
@ -280,13 +290,157 @@ class RequestExecutorService implements RequestExecutor {
ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext())
);
completeExecution(task);
var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> {
var endpointHandler = new RateLimitingEndpointHandler(
Integer.toString(requestManager.rateLimitGrouping().hashCode()),
queueCreator,
settings,
requestSender,
clock,
requestManager.rateLimitSettings(),
this::isShutdown,
rateLimiterCreator
);
endpointHandler.init();
return endpointHandler;
});
endpoint.enqueue(task);
}
private void completeExecution(RequestTask task) {
/**
     * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests.
     * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been exceeded,
     * they will be sent as soon as a thread is available.
*/
private static class RateLimitingEndpointHandler {
private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE;
private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO;
private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class);
private static final int ACCUMULATED_TOKENS_LIMIT = 1;
private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
private final Supplier<Boolean> isShutdownMethod;
private final RequestSender requestSender;
private final String id;
private final AtomicReference<Instant> timeOfLastEnqueue = new AtomicReference<>();
private final Clock clock;
private final RateLimiter rateLimiter;
private final RequestExecutorServiceSettings requestExecutorServiceSettings;
RateLimitingEndpointHandler(
String id,
AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
RequestExecutorServiceSettings settings,
RequestSender requestSender,
Clock clock,
RateLimitSettings rateLimitSettings,
Supplier<Boolean> isShutdownMethod,
RateLimiterCreator rateLimiterCreator
) {
this.requestExecutorServiceSettings = Objects.requireNonNull(settings);
this.id = Objects.requireNonNull(id);
this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
this.requestSender = Objects.requireNonNull(requestSender);
this.clock = Objects.requireNonNull(clock);
this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod);
Objects.requireNonNull(rateLimitSettings);
Objects.requireNonNull(rateLimiterCreator);
rateLimiter = rateLimiterCreator.create(
ACCUMULATED_TOKENS_LIMIT,
rateLimitSettings.requestsPerTimeUnit(),
rateLimitSettings.timeUnit()
);
}
public void init() {
requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange);
}
private void onCapacityChange(int capacity) {
logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity));
try {
queue.setCapacity(capacity);
} catch (Exception e) {
logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e);
}
}
public int queueSize() {
return queue.size();
}
public boolean isShutdown() {
return isShutdownMethod.get();
}
public Instant timeOfLastEnqueue() {
return timeOfLastEnqueue.get();
}
public synchronized TimeValue executeEnqueuedTask() {
try {
return executeEnqueuedTaskInternal();
} catch (Exception e) {
logger.warn(format("Executor service grouping [%s] failed to execute request", id), e);
// We attempted a task but failed, so report that a task was executed so the caller checks for more work promptly
return EXECUTED_A_TASK;
}
}
private TimeValue executeEnqueuedTaskInternal() {
var timeBeforeAvailableToken = rateLimiter.timeToReserve(1);
if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) {
return timeBeforeAvailableToken;
}
var task = queue.poll();
// TODO Batching - in a situation where no new tasks are queued we'll want to execute any prepared tasks
// So we'll need to check for null and call a helper method executePreparedTasks()
if (shouldExecuteTask(task) == false) {
return NO_TASKS_AVAILABLE;
}
// We should never have to wait because we checked above
var reserveRes = rateLimiter.reserve(1);
assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have";
task.getRequestManager()
.execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener());
return EXECUTED_A_TASK;
}
private static boolean shouldExecuteTask(RejectableTask task) {
return task != null && isNoopRequest(task) == false && task.hasCompleted() == false;
}
private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
return inferenceRequest.getRequestManager() == null
|| inferenceRequest.getInput() == null
|| inferenceRequest.getListener() == null;
}
private static boolean shouldExecuteImmediately(TimeValue delay) {
return delay.duration() == 0;
}
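The return value of `executeEnqueuedTaskInternal()` above doubles as a scheduling hint: a positive delay means "rate limited, retry after this long", `EXECUTED_A_TASK` (zero) means "look for more work immediately", and `NO_TASKS_AVAILABLE` (a maximal value) means "queue empty". A sketch of that contract, with plain `Runnable` tasks standing in for `RejectableTask`:

```java
import java.time.Duration;
import java.util.ArrayDeque;
import java.util.Queue;

// Sketch of the return contract: positive = rate limited, ZERO = executed,
// very large = nothing queued.
class ExecuteOnceDemo {
    static final Duration NO_TASKS_AVAILABLE = Duration.ofDays(365);
    static final Duration EXECUTED_A_TASK = Duration.ZERO;

    static Duration executeOnce(Queue<Runnable> queue, Duration timeBeforeAvailableToken) {
        if (!timeBeforeAvailableToken.isZero()) {
            return timeBeforeAvailableToken; // rate limited: caller sleeps this long
        }
        Runnable task = queue.poll();
        if (task == null) {
            return NO_TASKS_AVAILABLE; // nothing to do until something is enqueued
        }
        task.run();
        return EXECUTED_A_TASK; // poll again right away
    }
}
```

Checking the rate limiter before polling the queue (as the real method does) ensures a task is never dequeued only to sit waiting for a token.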
public void enqueue(RequestTask task) {
timeOfLastEnqueue.set(Instant.now(clock));
if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName),
format(
"Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown",
task.getRequestManager().inferenceEntityId(),
id
),
true
);
@ -294,24 +448,72 @@ class RequestExecutorService implements RequestExecutor {
return;
}
boolean added = queue.offer(task);
if (added == false) {
var addedToQueue = queue.offer(task);
if (addedToQueue == false) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
format("Failed to execute task because the http executor service [%s] queue is full", serviceName),
format(
"Failed to execute task for inference id [%s] because the request service [%s] queue is full",
task.getRequestManager().inferenceEntityId(),
id
),
false
);
task.onRejection(rejected);
} else if (isShutdown()) {
// It is possible that a shutdown and notification request occurred after we initially checked for shutdown above
// If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if
// we shut down and if so we'll redo the notification
notifyRequestsOfShutdown();
}
}
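The re-check of `isShutdown()` after a successful `offer` closes the race described in the comment: a concurrent shutdown may drain the queue between the initial check and the offer, leaving the task stranded. A standalone sketch of the check / offer / re-check pattern (types and names here are illustrative; the real code rejects with `EsRejectedExecutionException`):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

// A task offered concurrently with shutdown is drained and rejected instead
// of sitting in the queue indefinitely.
public class ShutdownAwareQueue {
    private final LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>(10);
    private final AtomicBoolean isShutdown = new AtomicBoolean();
    private final List<String> rejected = new ArrayList<>();

    public void enqueue(String task) {
        if (isShutdown.get()) {            // fast path: already shut down
            rejected.add(task);
            return;
        }
        if (queue.offer(task) == false) {  // queue full
            rejected.add(task);
        } else if (isShutdown.get()) {
            // shutdown may have drained the queue between offer() and here,
            // so drain again to guarantee this task is notified
            drainAndReject();
        }
    }

    public void shutdown() {
        isShutdown.set(true);
        drainAndReject();
    }

    private synchronized void drainAndReject() {
        List<String> notExecuted = new ArrayList<>();
        queue.drainTo(notExecuted);
        rejected.addAll(notExecuted);
    }

    public List<String> rejectedTasks() {
        return rejected;
    }
}
```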
// default for testing
int remainingQueueCapacity() {
public synchronized void notifyRequestsOfShutdown() {
assert isShutdown() : "Requests should only be notified if the executor is shutting down";
try {
List<RejectableTask> notExecuted = new ArrayList<>();
queue.drainTo(notExecuted);
rejectTasks(notExecuted);
} catch (Exception e) {
logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id));
}
}
private void rejectTasks(List<RejectableTask> tasks) {
for (var task : tasks) {
rejectTaskForShutdown(task);
}
}
private void rejectTaskForShutdown(RejectableTask task) {
try {
task.onRejection(
new EsRejectedExecutionException(
format(
"Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request",
id,
task.getRequestManager().inferenceEntityId()
),
true
)
);
} catch (Exception e) {
logger.warn(
format(
"Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown",
task.getRequestManager().inferenceEntityId(),
id
)
);
}
}
public int remainingCapacity() {
return queue.remainingCapacity();
}
public void close() {
requestExecutorServiceSettings.deregisterQueueCapacityCallback(id);
}
}
}

@ -10,9 +10,12 @@ package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import java.util.ArrayList;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
public class RequestExecutorServiceSettings {
@ -29,37 +32,108 @@ public class RequestExecutorServiceSettings {
Setting.Property.Dynamic
);
private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50);
/**
* Defines how often all the rate limit groups are polled for tasks. Setting this to a very low value could result
* in a busy loop if there are no tasks available to handle.
*/
static final Setting<TimeValue> TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting(
"xpack.inference.http.request_executor.task_poll_frequency",
DEFAULT_TASK_POLL_FREQUENCY_TIME,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
/**
* Defines how often a thread will check for rate limit groups that are stale.
*/
static final Setting<TimeValue> RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING = Setting.timeSetting(
"xpack.inference.http.request_executor.rate_limit_group_cleanup_interval",
DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION = TimeValue.timeValueDays(10);
/**
* Defines how long a rate limit group must go without receiving new tasks before it is classified as stale.
* Once it is classified as stale, it can be removed when the cleanup thread executes.
*/
static final Setting<TimeValue> RATE_LIMIT_GROUP_STALE_DURATION_SETTING = Setting.timeSetting(
"xpack.inference.http.request_executor.rate_limit_group_stale_duration",
DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
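Putting the two stale-duration settings together: a rate limit group counts as stale once its `timeOfLastEnqueue` is older than the configured stale duration, and the cleanup task can then drop it on its next pass. A sketch of that classification (`isStale` is an assumed helper, not the actual cleanup code):

```java
import java.time.Duration;
import java.time.Instant;

// A rate limit grouping is stale when no task has been enqueued for longer
// than the configured stale duration (10 days by default).
public class StaleCheck {
    public static boolean isStale(Instant timeOfLastEnqueue, Instant now, Duration staleDuration) {
        return Duration.between(timeOfLastEnqueue, now).compareTo(staleDuration) > 0;
    }

    public static void main(String[] args) {
        Instant now = Instant.parse("2024-06-05T00:00:00Z");
        Duration staleAfter = Duration.ofDays(10); // the setting's default
        System.out.println(isStale(now.minus(Duration.ofDays(11)), now, staleAfter)); // true
        System.out.println(isStale(now.minus(Duration.ofHours(1)), now, staleAfter)); // false
    }
}
```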
public static List<Setting<?>> getSettingsDefinitions() {
return List.of(TASK_QUEUE_CAPACITY_SETTING);
return List.of(
TASK_QUEUE_CAPACITY_SETTING,
TASK_POLL_FREQUENCY_SETTING,
RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING,
RATE_LIMIT_GROUP_STALE_DURATION_SETTING
);
}
private volatile int queueCapacity;
private final List<Consumer<Integer>> queueCapacityCallbacks = new ArrayList<Consumer<Integer>>();
private volatile TimeValue taskPollFrequency;
private volatile Duration rateLimitGroupStaleDuration;
private final ConcurrentMap<String, Consumer<Integer>> queueCapacityCallbacks = new ConcurrentHashMap<>();
public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) {
queueCapacity = TASK_QUEUE_CAPACITY_SETTING.get(settings);
taskPollFrequency = TASK_POLL_FREQUENCY_SETTING.get(settings);
setRateLimitGroupStaleDuration(RATE_LIMIT_GROUP_STALE_DURATION_SETTING.get(settings));
addSettingsUpdateConsumers(clusterService);
}
private void addSettingsUpdateConsumers(ClusterService clusterService) {
clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity);
clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency);
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(RATE_LIMIT_GROUP_STALE_DURATION_SETTING, this::setRateLimitGroupStaleDuration);
}
// default for testing
void setQueueCapacity(int queueCapacity) {
this.queueCapacity = queueCapacity;
for (var callback : queueCapacityCallbacks) {
for (var callback : queueCapacityCallbacks.values()) {
callback.accept(queueCapacity);
}
}
void registerQueueCapacityCallback(Consumer<Integer> onChangeCapacityCallback) {
queueCapacityCallbacks.add(onChangeCapacityCallback);
private void setTaskPollFrequency(TimeValue taskPollFrequency) {
this.taskPollFrequency = taskPollFrequency;
}
private void setRateLimitGroupStaleDuration(TimeValue staleDuration) {
rateLimitGroupStaleDuration = toDuration(staleDuration);
}
private static Duration toDuration(TimeValue timeValue) {
return Duration.of(timeValue.duration(), timeValue.timeUnit().toChronoUnit());
}
void registerQueueCapacityCallback(String id, Consumer<Integer> onChangeCapacityCallback) {
queueCapacityCallbacks.put(id, onChangeCapacityCallback);
}
void deregisterQueueCapacityCallback(String id) {
queueCapacityCallbacks.remove(id);
}
int getQueueCapacity() {
return queueCapacity;
}
TimeValue getTaskPollFrequency() {
return taskPollFrequency;
}
Duration getRateLimitGroupStaleDuration() {
return rateLimitGroupStaleDuration;
}
}
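Keying the callbacks by id (rather than the previous `List<Consumer<Integer>>`) is what makes `deregisterQueueCapacityCallback` possible: each rate limit grouping registers under its own id and removes exactly that entry in `close()`, which a plain list of lambdas cannot support. A minimal sketch of the pattern (illustrative class, not the real settings object):

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;

// Id-keyed callback registry: register on construction, deregister on close.
public class CallbackRegistry {
    private final ConcurrentMap<String, Consumer<Integer>> callbacks = new ConcurrentHashMap<>();

    public void register(String id, Consumer<Integer> onChangeCapacityCallback) {
        callbacks.put(id, onChangeCapacityCallback);
    }

    public void deregister(String id) {
        callbacks.remove(id);
    }

    public void publish(int newCapacity) {
        for (Consumer<Integer> callback : callbacks.values()) {
            callback.accept(newCapacity);
        }
    }
}
```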

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
@ -21,14 +20,17 @@ import java.util.function.Supplier;
* A contract for handling the sending of an inference request to a 3rd party service.
*/
public interface RequestManager extends RateLimitable {
Runnable create(
void execute(
@Nullable String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
);
// TODO For batching we'll add 2 new methods: prepare(query, input, ...) which will allow the individual
// managers to implement their own batching
// executePreparedRequest() which will execute all prepared requests, i.e. send the batch
String inferenceEntityId();
}

@ -111,7 +111,7 @@ class RequestTask implements RejectableTask {
}
@Override
public RequestManager getRequestCreator() {
public RequestManager getRequestManager() {
return requestCreator;
}
}

@ -1,48 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import java.util.Objects;
/**
* Handles executing a single inference request at a time.
*/
public class SingleRequestManager {
protected RetryingHttpSender requestSender;
public SingleRequestManager(RetryingHttpSender requestSender) {
this.requestSender = Objects.requireNonNull(requestSender);
}
public void execute(InferenceRequest inferenceRequest, HttpClientContext context) {
if (isNoopRequest(inferenceRequest) || inferenceRequest.hasCompleted()) {
return;
}
inferenceRequest.getRequestCreator()
.create(
inferenceRequest.getQuery(),
inferenceRequest.getInput(),
requestSender,
inferenceRequest.getRequestCompletedFunction(),
context,
inferenceRequest.getListener()
)
.run();
}
private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
return inferenceRequest.getRequestCreator() == null
|| inferenceRequest.getInput() == null
|| inferenceRequest.getListener() == null;
}
}

@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService {
public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
Objects.requireNonNull(factory);
sender = factory.createSender(name());
sender = factory.createSender();
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

@ -56,7 +56,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completio
public class AzureAiStudioService extends SenderService {
private static final String NAME = "azureaistudio";
static final String NAME = "azureaistudio";
public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);

@ -44,7 +44,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
ConfigurationParseContext context
) {
String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
AzureAiStudioService.NAME,
context
);
AzureAiStudioEndpointType endpointType = extractRequiredEnum(
map,
ENDPOINT_TYPE_FIELD,
@ -118,13 +124,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
protected void addXContentFields(XContentBuilder builder, Params params) throws IOException {
this.addExposedXContentFields(builder, params);
rateLimitSettings.toXContent(builder, params);
}
protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException {
builder.field(TARGET_FIELD, this.target);
builder.field(PROVIDER_FIELD, this.provider);
builder.field(ENDPOINT_TYPE_FIELD, this.endpointType);
rateLimitSettings.toXContent(builder, params);
}
}

@ -135,7 +135,15 @@ public class AzureOpenAiService extends SenderService {
);
}
case COMPLETION -> {
return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
return new AzureOpenAiCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);
}
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
}

@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
@ -37,13 +38,14 @@ public class AzureOpenAiCompletionModel extends AzureOpenAiModel {
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings),
AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context),
AzureOpenAiCompletionTaskSettings.fromMap(taskSettings),
AzureOpenAiSecretSettings.fromMap(secrets)
);

@ -17,7 +17,9 @@ import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -55,10 +57,10 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map) {
public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var settings = fromMap(map, validationException);
var settings = fromMap(map, validationException, context);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -69,12 +71,19 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap(
Map<String, Object> map,
ValidationException validationException
ValidationException validationException,
ConfigurationParseContext context
) {
String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException);
String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
AzureOpenAiService.NAME,
context
);
return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings);
}
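The `RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, ...)` calls repeated across the service settings classes all follow the same shape: remove an optional `rate_limit` object from the raw settings map, read `requests_per_minute` from it, and fall back to the service's default. A simplified sketch of that extraction (the real implementation also validates the value and behaves differently per `ConfigurationParseContext`; the `Limits` record here is illustrative):

```java
import java.util.HashMap;
import java.util.Map;

// Optional nested-object extraction with a per-service default.
public class RateLimitExtraction {
    public record Limits(long requestsPerMinute) {}

    public static Limits of(Map<String, Object> map, Limits defaultLimits) {
        Object raw = map.remove("rate_limit");
        if (raw instanceof Map<?, ?> rateLimitMap) {
            Object rpm = rateLimitMap.get("requests_per_minute");
            if (rpm instanceof Number n) {
                return new Limits(n.longValue());
            }
        }
        return defaultLimits;
    }

    public static void main(String[] args) {
        Map<String, Object> settings = new HashMap<>();
        settings.put("rate_limit", Map.of("requests_per_minute", 500));
        System.out.println(of(settings, new Limits(120)).requestsPerMinute()); // 500
        System.out.println(of(new HashMap<>(), new Limits(120)).requestsPerMinute()); // 120
    }
}
```

Removing the key from the map matters: the services call `throwIfNotEmptyMap` afterwards, so any field left behind would be reported as unrecognized.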
@ -137,7 +146,6 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
@ -148,6 +156,7 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
builder.field(RESOURCE_NAME, resourceName);
builder.field(DEPLOYMENT_ID, deploymentId);
builder.field(API_VERSION, apiVersion);
rateLimitSettings.toXContent(builder, params);
return builder;
}

@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -90,7 +91,13 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
AzureOpenAiService.NAME,
context
);
Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);
@ -245,8 +252,6 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
builder.endObject();
@ -268,6 +273,7 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
if (similarity != null) {
builder.field(SIMILARITY, similarity);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}

@ -51,6 +51,11 @@ import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFie
public class CohereService extends SenderService {
public static final String NAME = "cohere";
// TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to
// the Cohere*RequestManager via the CohereActionCreator class
// The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
// on every request
public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
@ -131,7 +136,15 @@ public class CohereService extends SenderService {
context
);
case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
case COMPLETION -> new CohereCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}

@ -58,7 +58,13 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
CohereService.NAME,
context
);
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@ -173,10 +179,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
}
public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
return builder;
return toXContentFragmentOfExposedFields(builder, params);
}
@Override
@ -196,6 +199,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
if (modelId != null) {
builder.field(MODEL_ID, modelId);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}

@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@ -30,13 +31,14 @@ public class CohereCompletionModel extends CohereModel {
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
modelId,
taskType,
service,
CohereCompletionServiceSettings.fromMap(serviceSettings),
CohereCompletionServiceSettings.fromMap(serviceSettings, context),
EmptyTaskSettings.INSTANCE,
DefaultSecretSettings.fromMap(secrets)
);

@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -39,12 +41,18 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
// 10K requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
public static CohereCompletionServiceSettings fromMap(Map<String, Object> map) {
public static CohereCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
CohereService.NAME,
context
);
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
@ -94,7 +102,6 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
@ -127,6 +134,7 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
if (modelId != null) {
builder.field(MODEL_ID, modelId);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}

@ -108,7 +108,8 @@ public class GoogleAiStudioService extends SenderService {
NAME,
serviceSettings,
taskSettings,
secretSettings
secretSettings,
context
);
case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel(
inferenceEntityId,
@ -116,7 +117,8 @@ public class GoogleAiStudioService extends SenderService {
NAME,
serviceSettings,
taskSettings,
secretSettings
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};

@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@ -37,13 +38,14 @@ public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel {
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Map<String, Object> secrets
Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings),
GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings, context),
EmptyTaskSettings.INSTANCE,
DefaultSecretSettings.fromMap(secrets)
);

@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -40,11 +42,17 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);
public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map) {
public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
GoogleAiStudioService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -82,7 +90,6 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
@ -107,6 +114,7 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_ID, modelId);
rateLimitSettings.toXContent(builder, params);
return builder;
}

@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@ -37,13 +38,14 @@ public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel {
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Map<String, Object> secrets
Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings),
GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
EmptyTaskSettings.INSTANCE,
DefaultSecretSettings.fromMap(secrets)
);

@ -18,7 +18,9 @@ import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -47,7 +49,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);
public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map) {
public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@ -59,7 +61,13 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
);
SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
GoogleAiStudioService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -134,7 +142,6 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
@ -174,6 +181,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
if (similarity != null) {
builder.field(SIMILARITY, similarity);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}

@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@ -62,7 +63,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
taskType,
serviceSettingsMap,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name())
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
);
throwIfNotEmptyMap(config, name());
@ -89,7 +91,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
taskType,
serviceSettingsMap,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name())
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
);
}
@ -97,7 +100,14 @@ public abstract class HuggingFaceBaseService extends SenderService {
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
return createModel(inferenceEntityId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(inferenceEntityId, name()));
return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
);
}
protected abstract HuggingFaceModel createModel(
@ -105,7 +115,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> secretSettings,
String failureMessage
String failureMessage,
ConfigurationParseContext context
);
@Override


@ -16,6 +16,7 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
@ -36,11 +37,19 @@ public class HuggingFaceService extends HuggingFaceBaseService {
TaskType taskType,
Map<String, Object> serviceSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
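As the hunk above shows, the parse context is now threaded through `createModel` into each model constructor, so settings parsing knows whether it is handling a live create request or a persisted config. A simplified sketch of that dispatch with hypothetical, stripped-down types:

```java
public class CreateModelSketch {
    public enum TaskType { TEXT_EMBEDDING, SPARSE_EMBEDDING, COMPLETION }
    public enum ConfigurationParseContext { REQUEST, PERSISTENT }

    // Minimal stand-in for a service model; real models parse their settings
    // maps here and pass the context down to the settings parsers.
    public record Model(TaskType taskType, ConfigurationParseContext context) {}

    public static Model createModel(TaskType taskType, ConfigurationParseContext context) {
        return switch (taskType) {
            case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new Model(taskType, context);
            default -> throw new IllegalStateException("task type [" + taskType + "] is not supported");
        };
    }

    public static void main(String[] args) {
        System.out.println(createModel(TaskType.TEXT_EMBEDDING, ConfigurationParseContext.REQUEST));
    }
}
```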


@ -18,6 +18,7 @@ import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -43,14 +44,20 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
// 3000 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
public static HuggingFaceServiceSettings fromMap(Map<String, Object> map) {
public static HuggingFaceServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var uri = extractUri(map, URL, validationException);
SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
HuggingFaceService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -119,7 +126,6 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
}
@ -136,6 +142,7 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
if (maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}


@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@ -24,13 +25,14 @@ public class HuggingFaceElserModel extends HuggingFaceModel {
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
@Nullable Map<String, Object> secrets
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
HuggingFaceElserServiceSettings.fromMap(serviceSettings),
HuggingFaceElserServiceSettings.fromMap(serviceSettings, context),
DefaultSecretSettings.fromMap(secrets)
);
}


@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
@ -38,10 +39,11 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
TaskType taskType,
Map<String, Object> serviceSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}


@ -15,7 +15,9 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -40,10 +42,16 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
// 3000 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map) {
public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var uri = extractUri(map, URL, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
HuggingFaceService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -93,7 +101,6 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
@ -103,6 +110,7 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(URL, uri.toString());
builder.field(MAX_INPUT_TOKENS, ELSER_TOKEN_LIMIT);
rateLimitSettings.toXContent(builder, params);
return builder;
}


@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@ -25,13 +26,14 @@ public class HuggingFaceEmbeddingsModel extends HuggingFaceModel {
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
@Nullable Map<String, Object> secrets
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
HuggingFaceServiceSettings.fromMap(serviceSettings),
HuggingFaceServiceSettings.fromMap(serviceSettings, context),
DefaultSecretSettings.fromMap(secrets)
);
}


@ -18,6 +18,7 @@ import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -59,7 +60,13 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
ModelConfigurations.SERVICE_SETTINGS,
validationException
);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
MistralService.NAME,
context
);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
if (validationException.validationErrors().isEmpty() == false) {
@ -141,7 +148,6 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
this.toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
}
@ -159,6 +165,7 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
if (this.maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, this.maxInputTokens);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}


@ -138,7 +138,8 @@ public class OpenAiService extends SenderService {
NAME,
serviceSettings,
taskSettings,
secretSettings
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};


@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.openai.OpenAiModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@ -35,13 +36,14 @@ public class OpenAiChatCompletionModel extends OpenAiModel {
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
OpenAiChatCompletionServiceSettings.fromMap(serviceSettings),
OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
DefaultSecretSettings.fromMap(secrets)
);


@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -47,7 +49,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
// 500 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map) {
public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@ -58,7 +60,13 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
OpenAiService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -142,7 +150,6 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
@ -163,6 +170,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
if (maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}


@ -20,6 +20,7 @@ import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -66,7 +67,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
// passed at that time and never throw.
ValidationException validationException = new ValidationException();
var commonFields = fromMap(map, validationException);
var commonFields = fromMap(map, validationException, ConfigurationParseContext.PERSISTENT);
Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
if (dimensionsSetByUser == null) {
@ -80,7 +81,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
private static OpenAiEmbeddingsServiceSettings fromRequestMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
var commonFields = fromMap(map, validationException);
var commonFields = fromMap(map, validationException, ConfigurationParseContext.REQUEST);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
@ -89,7 +90,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
return new OpenAiEmbeddingsServiceSettings(commonFields, commonFields.dimensions != null);
}
private static CommonFields fromMap(Map<String, Object> map, ValidationException validationException) {
private static CommonFields fromMap(
Map<String, Object> map,
ValidationException validationException,
ConfigurationParseContext context
) {
String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException);
@ -98,7 +103,13 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
OpenAiService.NAME,
context
);
return new CommonFields(modelId, uri, organizationId, similarity, maxInputTokens, dims, rateLimitSettings);
}
@ -258,7 +269,6 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
if (dimensionsSetByUser != null) {
builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
@ -286,6 +296,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
if (maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
}
rateLimitSettings.toXContent(builder, params);
return builder;
}


@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import java.io.IOException;
import java.util.Map;
@ -21,19 +22,29 @@ import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
public class RateLimitSettings implements Writeable, ToXContentFragment {
public static final String FIELD_NAME = "rate_limit";
public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute";
private final long requestsPerTimeUnit;
private final TimeUnit timeUnit;
public static RateLimitSettings of(Map<String, Object> map, RateLimitSettings defaultValue, ValidationException validationException) {
public static RateLimitSettings of(
Map<String, Object> map,
RateLimitSettings defaultValue,
ValidationException validationException,
String serviceName,
ConfigurationParseContext context
) {
Map<String, Object> settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME);
var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException);
if (ConfigurationParseContext.isRequestContext(context)) {
throwIfNotEmptyMap(settings, serviceName);
}
return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute);
}
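The `RateLimitSettings.of` change above is the behavioral core of the refactor: unknown keys under `rate_limit` now fail a create request, but are still tolerated when re-reading a persisted config, so previously stored settings keep loading. A hedged, stand-alone sketch of that rule with simplified types (`long` in place of `RateLimitSettings`, `IllegalArgumentException` in place of `ValidationException`):

```java
import java.util.HashMap;
import java.util.Map;

public class RateLimitParseSketch {
    public enum ConfigurationParseContext { REQUEST, PERSISTENT }

    public static long of(Map<String, Object> settings, long defaultRpm, String serviceName, ConfigurationParseContext context) {
        // Consume the known field first, then decide what leftovers mean.
        Object rpm = settings.remove("requests_per_minute");
        if (context == ConfigurationParseContext.REQUEST && settings.isEmpty() == false) {
            // Strict on requests: surface typos and unsupported fields to the caller.
            throw new IllegalArgumentException(
                "Model configuration contains settings " + settings.keySet() + " unknown to the [" + serviceName + "] service"
            );
        }
        // Lenient on persisted configs: ignore leftovers, fall back to the default.
        return rpm == null ? defaultRpm : ((Number) rpm).longValue();
    }

    public static void main(String[] args) {
        Map<String, Object> persisted = new HashMap<>(Map.of("requests_per_minute", 3000, "legacy_field", true));
        System.out.println(of(persisted, 500, "hugging_face", ConfigurationParseContext.PERSISTENT));
    }
}
```

This strict/lenient split mirrors the `ConfigurationParseContext.REQUEST` / `PERSISTENT` values threaded through the service classes earlier in the diff.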


@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@ -92,7 +93,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
TruncatorTests.createTruncator()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
@ -141,7 +142,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
TruncatorTests.createTruncator()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));


@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
@ -82,7 +83,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -132,7 +133,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -183,7 +184,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
// timeout as zero for no retries
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -237,7 +238,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
// note - there is no complete documentation on Azure's error messages
@ -313,7 +314,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
// note - there is no complete documentation on Azure's error messages
@ -389,7 +390,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testExecute_TruncatesInputBeforeSending() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -440,7 +441,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -498,7 +499,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -554,7 +555,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
// timeout as zero for no retries
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
// "choices" missing


@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
import static org.hamcrest.Matchers.hasSize;
@ -77,7 +78,7 @@ public class AzureOpenAiCompletionActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """


@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
@ -81,7 +82,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
mockClusterServiceEmpty()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """


@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@ -73,7 +74,7 @@ public class CohereActionCreatorTests extends ESTestCase {
public void testCreate_CohereEmbeddingsModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -154,7 +155,7 @@ public class CohereActionCreatorTests extends ESTestCase {
public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@ -214,7 +215,7 @@ public class CohereActionCreatorTests extends ESTestCase {
public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """


@ -77,7 +77,7 @@ public class CohereCompletionActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -138,7 +138,7 @@ public class CohereCompletionActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -290,7 +290,7 @@ public class CohereCompletionActionTests extends ESTestCase {
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """

View file

@@ -81,7 +81,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -162,7 +162,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """

View file

@@ -74,7 +74,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -206,7 +206,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """

View file

@@ -79,7 +79,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
var input = "input";
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = senderFactory.createSender()) {
sender.start();
String responseJson = """

View file

@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.contains;
@@ -75,7 +76,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -131,7 +132,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -187,7 +188,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -239,7 +240,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
// this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]]
@@ -292,7 +293,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJsonContentTooLarge = """
@@ -357,7 +358,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
public void testExecute_TruncatesInputBeforeSending() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """

View file

@@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -74,7 +75,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiEmbeddingsModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -127,7 +128,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -179,7 +180,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -238,7 +239,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -292,7 +293,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiChatCompletionModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -355,7 +356,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -417,7 +418,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -486,7 +487,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -552,7 +553,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
var contentTooLargeErrorMessage =
@@ -635,7 +636,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
var contentTooLargeErrorMessage =
@@ -718,7 +719,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testExecute_TruncatesInputBeforeSending() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """

View file

@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -80,7 +81,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -234,7 +235,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """

View file

@@ -79,7 +79,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
mockClusterServiceEmpty()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = senderFactory.createSender()) {
sender.start();
String responseJson = """

View file

@@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.mock;
public class BaseRequestManagerTests extends ESTestCase {
public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() {
int val1 = 1;
int val2 = 1;
var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
assertThat(manager1.rateLimitGrouping(), is(manager2.rateLimitGrouping()));
}
public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() {
int val1 = 1;
var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
}
public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() {
int val1 = 1;
var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
}
}
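The value-equality behavior these tests assert (distinct references with equal grouping values and equal `RateLimitSettings` share a bucket, while a differing limit or time unit splits buckets) can be sketched with a plain record key. This is an illustrative sketch only, not the actual `BaseRequestManager` grouping logic:

```java
import java.util.concurrent.TimeUnit;

public class RateLimitGroupingSketch {
    // Hypothetical grouping key: equality is by value, not by reference,
    // combining the service-specific grouping object with the limit settings.
    record GroupingKey(Object grouping, long requestsPerTimeUnit, TimeUnit unit) {}

    public static void main(String[] args) {
        var a = new GroupingKey(1, 1, TimeUnit.SECONDS);
        var b = new GroupingKey(1, 1, TimeUnit.SECONDS); // distinct reference, equal value
        var c = new GroupingKey(1, 2, TimeUnit.SECONDS); // different limit
        var d = new GroupingKey(1, 1, TimeUnit.DAYS);    // different time unit

        System.out.println(a.equals(b)); // true
        System.out.println(a.equals(c)); // false
        System.out.println(a.equals(d)); // false
    }
}
```

Records derive `equals` and `hashCode` from their components, which is exactly the property the three tests above rely on.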

View file

@@ -79,7 +79,7 @@ public class HttpRequestSenderTests extends ESTestCase {
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = createSenderFactory(clientManager, threadRef);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
@@ -135,11 +135,11 @@ public class HttpRequestSenderTests extends ESTestCase {
mockClusterServiceEmpty()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = senderFactory.createSender()) {
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var thrownException = expectThrows(
AssertionError.class,
() -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
() -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
);
assertThat(thrownException.getMessage(), is("call start() before sending a request"));
}
@@ -155,17 +155,12 @@ public class HttpRequestSenderTests extends ESTestCase {
mockClusterServiceEmpty()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = senderFactory.createSender()) {
assertThat(sender, instanceOf(HttpRequestSender.class));
sender.start();
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.send(
ExecutableRequestCreatorTests.createMock(),
new DocumentsOnlyInput(List.of()),
TimeValue.timeValueNanos(1),
listener
);
sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
@@ -186,16 +181,11 @@ public class HttpRequestSenderTests extends ESTestCase {
mockClusterServiceEmpty()
);
try (var sender = senderFactory.createSender("test_service")) {
try (var sender = senderFactory.createSender()) {
sender.start();
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.send(
ExecutableRequestCreatorTests.createMock(),
new DocumentsOnlyInput(List.of()),
TimeValue.timeValueNanos(1),
listener
);
sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
@@ -220,6 +210,7 @@ public class HttpRequestSenderTests extends ESTestCase {
when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService);
when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class));
when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
return new HttpRequestSender.Factory(
ServiceComponentsTests.createWithEmptySettings(mockThreadPool),
@@ -248,7 +239,7 @@ public class HttpRequestSenderTests extends ESTestCase {
);
}
public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) {
return factory.createSender(serviceName);
public static Sender createSender(HttpRequestSender.Factory factory) {
return factory.createSender();
}
}

View file

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
@@ -18,12 +19,23 @@ public class RequestExecutorServiceSettingsTests {
}
public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) {
return createRequestExecutorServiceSettings(queueCapacity, null);
}
public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(
@Nullable Integer queueCapacity,
@Nullable TimeValue staleDuration
) {
var settingsBuilder = Settings.builder();
if (queueCapacity != null) {
settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity);
}
if (staleDuration != null) {
settingsBuilder.put(RequestExecutorServiceSettings.RATE_LIMIT_GROUP_STALE_DURATION_SETTING.getKey(), staleDuration);
}
return createRequestExecutorServiceSettings(settingsBuilder.build());
}
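The new `RATE_LIMIT_GROUP_STALE_DURATION_SETTING` exercised here controls how long an idle rate-limit group survives before the periodic cleanup task evicts it. A minimal sketch of that eviction pattern, with hypothetical names (the real `RequestExecutorService` wiring differs):

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class StaleGroupCleanupSketch {
    private final Map<Object, Instant> lastUsedByGroup = new ConcurrentHashMap<>();
    private final Duration staleDuration;

    StaleGroupCleanupSketch(Duration staleDuration) {
        this.staleDuration = staleDuration;
    }

    /** Record that a rate-limit group was just used. */
    void touch(Object groupingKey, Instant now) {
        lastUsedByGroup.put(groupingKey, now);
    }

    /** Evict groups idle longer than the stale duration; a real service would
     *  run this periodically, e.g. via threadPool.scheduleWithFixedDelay(...). */
    void removeStaleGroupings(Instant now) {
        Instant cutoff = now.minus(staleDuration);
        lastUsedByGroup.entrySet().removeIf(e -> e.getValue().isBefore(cutoff));
    }

    int groupCount() {
        return lastUsedByGroup.size();
    }

    public static void main(String[] args) {
        var cleanup = new StaleGroupCleanupSketch(Duration.ofDays(1));
        Instant t0 = Instant.parse("2024-01-01T00:00:00Z");
        cleanup.touch("group-a", t0);

        cleanup.removeStaleGroupings(t0.plus(Duration.ofHours(1)));
        System.out.println(cleanup.groupCount()); // 1: still fresh

        cleanup.removeStaleGroupings(t0.plus(Duration.ofDays(2)));
        System.out.println(cleanup.groupCount()); // 0: evicted as stale
    }
}
```

Passing `now` explicitly keeps the eviction logic deterministic under test, much as the executor tests inject a `Clock`.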

View file

@@ -18,13 +18,19 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.RateLimiter;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
@@ -42,10 +48,13 @@ import static org.elasticsearch.xpack.inference.external.http.sender.RequestExec
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
public class RequestExecutorServiceTests extends ESTestCase {
@@ -70,7 +79,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
public void testQueueSize_IsOne() {
var service = createRequestExecutorServiceWithMocks();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
assertThat(service.queueSize(), is(1));
}
@@ -92,7 +101,20 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertTrue(service.isTerminated());
}
public void testIsTerminated_AfterStopFromSeparateThread() throws Exception {
public void testCallingStartTwice_ThrowsAssertionException() throws InterruptedException {
var latch = new CountDownLatch(1);
var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class));
service.shutdown();
service.start();
latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
assertTrue(service.isTerminated());
var exception = expectThrows(AssertionError.class, service::start);
assertThat(exception.getMessage(), is("start() can only be called once"));
}
public void testIsTerminated_AfterStopFromSeparateThread() {
var waitToShutdown = new CountDownLatch(1);
var waitToReturnFromSend = new CountDownLatch(1);
@@ -127,41 +149,48 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertTrue(service.isTerminated());
}
public void testSend_AfterShutdown_Throws() {
public void testExecute_AfterShutdown_Throws() {
var service = createRequestExecutorServiceWithMocks();
service.shutdown();
var requestManager = RequestManagerTests.createMock("id");
var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to enqueue task because the http executor service [test_service] has already shutdown")
is(
Strings.format(
"Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown",
requestManager.rateLimitGrouping().hashCode()
)
)
);
assertTrue(thrownException.isExecutorShutdown());
}
public void testSend_Throws_WhenQueueIsFull() {
var service = new RequestExecutorService(
"test_service",
threadPool,
null,
createRequestExecutorServiceSettings(1),
new SingleRequestManager(mock(RetryingHttpSender.class))
);
public void testExecute_Throws_WhenQueueIsFull() {
var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class));
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
var requestManager = RequestManagerTests.createMock("id");
var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full")
is(
Strings.format(
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
requestManager.rateLimitGrouping().hashCode()
)
)
);
assertFalse(thrownException.isExecutorShutdown());
}
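The queue-full rejection asserted above follows the standard bounded-queue pattern: a non-blocking `offer` on a capacity-limited queue, with a failed offer translated into an `EsRejectedExecutionException` for the caller. A simplified sketch of just that mechanism:

```java
import java.util.concurrent.LinkedBlockingQueue;

public class BoundedQueueSketch {
    public static void main(String[] args) {
        // Capacity 1 mirrors createRequestExecutorServiceSettings(1) in the test.
        var queue = new LinkedBlockingQueue<Runnable>(1);

        System.out.println(queue.offer(() -> {})); // true: first task accepted
        System.out.println(queue.offer(() -> {})); // false: queue full, so the caller rejects the task
    }
}
```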
@@ -203,16 +232,11 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertTrue(service.isShutdown());
}
public void testSend_CallsOnFailure_WhenRequestTimesOut() {
public void testExecute_CallsOnFailure_WhenRequestTimesOut() {
var service = createRequestExecutorServiceWithMocks();
var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(
ExecutableRequestCreatorTests.createMock(),
new DocumentsOnlyInput(List.of()),
TimeValue.timeValueNanos(1),
listener
);
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
@@ -222,7 +246,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
);
}
public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
public void testExecute_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
var headerKey = "not empty";
var headerValue = "value";
@@ -270,7 +294,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
}
};
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
Future<?> executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service);
@@ -280,11 +304,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
}
public void testSend_NotifiesTasksOfShutdown() {
public void testExecute_NotifiesTasksOfShutdown() {
var service = createRequestExecutorServiceWithMocks();
var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id");
var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
service.shutdown();
service.start();
@@ -293,47 +318,62 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertThat(
thrownException.getMessage(),
is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
is(
Strings.format(
"Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
requestManager.rateLimitGrouping().hashCode()
)
)
);
assertTrue(thrownException.isExecutorShutdown());
assertTrue(service.isTerminated());
}
public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
@SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
var requestSender = mock(RetryingHttpSender.class);
var service = new RequestExecutorService(
getTestName(),
threadPool,
mockQueueCreator(queue),
null,
createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(mock(RetryingHttpSender.class))
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var requestManager = RequestManagerTests.createMock(requestSender, "id");
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
service.shutdown();
return null;
});
service.start();
assertTrue(service.isTerminated());
verify(queue, times(2)).take();
}
public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception {
public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
@SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
when(queue.take()).thenThrow(new InterruptedException("failed"));
var sleeper = mock(RequestExecutorService.Sleeper.class);
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
var service = new RequestExecutorService(
getTestName(),
threadPool,
mockQueueCreator(queue),
null,
createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(mock(RetryingHttpSender.class))
mock(RetryingHttpSender.class),
Clock.systemUTC(),
sleeper,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
Future<?> executorTermination = threadPool.generic().submit(() -> {
@@ -347,66 +387,30 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated());
verify(queue, times(1)).take();
}
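These tests can inject a throwing `Sleeper` because the refactor replaced the blocking `queue.take()` with `poll()` plus an explicit sleep between rate-limited polls. A sketch of that test seam, with illustrative names rather than the actual Elasticsearch interfaces:

```java
import java.time.Duration;

public class SleeperSketch {
    /** Seam for tests: production sleeps for real, tests substitute a stub. */
    interface Sleeper {
        void sleep(Duration duration) throws InterruptedException;
    }

    static final Sleeper DEFAULT_SLEEPER = d -> Thread.sleep(d.toMillis());

    /** One iteration of a service loop: returns false when interruption
     *  should terminate the service, mirroring the test's expectation. */
    static boolean runOnce(Sleeper sleeper) {
        try {
            sleeper.sleep(Duration.ofMillis(1));
            return true; // keep polling
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false; // shut the service down
        }
    }

    public static void main(String[] args) {
        Sleeper throwing = d -> { throw new InterruptedException("failed"); };
        System.out.println(runOnce(throwing)); // false: interruption terminates the loop
    }
}
```

Extracting the sleep behind a one-method interface lets a test force the `InterruptedException` path deterministically instead of racing real thread interrupts.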
public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception {
var mockTask = mock(RejectableTask.class);
@SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
var service = new RequestExecutorService(
"test_service",
threadPool,
mockQueueCreator(queue),
null,
createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(mock(RetryingHttpSender.class))
);
doAnswer(invocation -> {
service.shutdown();
return mockTask;
}).doReturn(new NoopTask()).when(queue).take();
service.start();
assertTrue(service.isTerminated());
verify(queue, times(1)).take();
ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
verify(mockTask, times(1)).onRejection(argument.capture());
assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class));
assertThat(
argument.getValue().getMessage(),
is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
);
var rejectionException = (EsRejectedExecutionException) argument.getValue();
assertTrue(rejectionException.isExecutorShutdown());
}
public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException {
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
var service = new RequestExecutorService(threadPool, null, settings, requestSender);
service.execute(
ExecutableRequestCreatorTests.createMock(requestSender),
new DocumentsOnlyInput(List.of()),
null,
new PlainActionFuture<>()
);
service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
assertThat(service.queueSize(), is(1));
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
var requestManager = RequestManagerTests.createMock(requestSender, "id");
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full")
is(
Strings.format(
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
requestManager.rateLimitGrouping().hashCode()
)
)
);
settings.setQueueCapacity(2);
@@ -426,7 +430,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated());
assertThat(service.remainingQueueCapacity(), is(2));
assertThat(service.remainingQueueCapacity(requestManager), is(2));
}
public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException,
@@ -434,23 +438,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(3);
var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
var service = new RequestExecutorService(threadPool, null, settings, requestSender);
service.execute(
ExecutableRequestCreatorTests.createMock(requestSender),
RequestManagerTests.createMock(requestSender, "id"),
new DocumentsOnlyInput(List.of()),
null,
new PlainActionFuture<>()
);
service.execute(
ExecutableRequestCreatorTests.createMock(requestSender),
RequestManagerTests.createMock(requestSender, "id"),
new DocumentsOnlyInput(List.of()),
null,
new PlainActionFuture<>()
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
var requestManager = RequestManagerTests.createMock(requestSender, "id");
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
assertThat(service.queueSize(), is(3));
settings.setQueueCapacity(1);
@@ -470,7 +475,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated());
assertThat(service.remainingQueueCapacity(), is(1));
assertThat(service.remainingQueueCapacity(requestManager), is(1));
assertThat(service.queueSize(), is(0));
var thrownException = expectThrows(
@@ -479,7 +484,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
);
assertThat(
thrownException.getMessage(),
is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
is(
Strings.format(
"Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
requestManager.rateLimitGrouping().hashCode()
)
)
);
assertTrue(thrownException.isExecutorShutdown());
}
@@ -489,23 +499,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
var service = new RequestExecutorService(threadPool, null, settings, requestSender);
var requestManager = RequestManagerTests.createMock(requestSender);
service.execute(
ExecutableRequestCreatorTests.createMock(requestSender),
new DocumentsOnlyInput(List.of()),
null,
new PlainActionFuture<>()
);
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
assertThat(service.queueSize(), is(1));
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
service.execute(RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full")
is(
Strings.format(
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
requestManager.rateLimitGrouping().hashCode()
)
)
);
settings.setQueueCapacity(0);
@@ -525,7 +536,133 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated());
assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE));
assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE));
}
public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
var mockRateLimiter = mock(RateLimiter.class);
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService(
threadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
doAnswer(invocation -> {
service.shutdown();
return TimeValue.timeValueDays(1);
}).when(mockRateLimiter).timeToReserve(anyInt());
service.start();
verifyNoInteractions(requestSender);
}
public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() {
var mockRateLimiter = mock(RateLimiter.class);
when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0));
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService(
threadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0));
doAnswer(invocation -> {
service.shutdown();
return Void.TYPE;
}).when(requestSender).send(any(), any(), any(), any(), any(), any());
service.start();
verify(requestSender, times(1)).send(any(), any(), any(), any(), any(), any());
}
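The two tests above mock a `RateLimiter` whose `timeToReserve` reports how long a caller would have to wait without consuming tokens, while `reserve` actually takes them. As a hypothetical sketch only (the class, fields, and refill math here are illustrative assumptions, not Elasticsearch's `RateLimiter` implementation), the token-bucket semantics being mocked look roughly like this:

```java
import java.time.Duration;
import java.time.Instant;

// Illustrative token bucket: refills continuously at a fixed rate up to a capacity.
class TokenBucketSketch {
    private final double tokensPerSecond;
    private final double capacity;
    private double tokens;
    private Instant lastRefill;

    TokenBucketSketch(double tokensPerSecond, double capacity, Instant now) {
        this.tokensPerSecond = tokensPerSecond;
        this.capacity = capacity;
        this.tokens = capacity;
        this.lastRefill = now;
    }

    private void refill(Instant now) {
        double elapsedSeconds = Duration.between(lastRefill, now).toMillis() / 1000.0;
        tokens = Math.min(capacity, tokens + elapsedSeconds * tokensPerSecond);
        lastRefill = now;
    }

    // How long until `n` tokens are available; does not consume any tokens.
    Duration timeToReserve(int n, Instant now) {
        refill(now);
        if (tokens >= n) {
            return Duration.ZERO;
        }
        double missing = n - tokens;
        return Duration.ofMillis((long) Math.ceil(missing / tokensPerSecond * 1000));
    }

    // Consume `n` tokens and report the wait that consumption incurs.
    Duration reserve(int n, Instant now) {
        Duration wait = timeToReserve(n, now);
        tokens -= n;
        return wait;
    }

    public static void main(String[] args) {
        Instant t0 = Instant.parse("2024-01-01T00:00:00Z");
        var bucket = new TokenBucketSketch(1.0, 1.0, t0); // 1 request per second
        assert bucket.timeToReserve(1, t0).isZero();      // bucket starts full
        bucket.reserve(1, t0);                            // bucket now empty
        assert !bucket.timeToReserve(1, t0).isZero();     // would have to wait
        assert bucket.timeToReserve(1, t0.plusSeconds(2)).isZero(); // refilled
        System.out.println("ok");
    }
}
```

This is why the first test can park a request indefinitely by stubbing `timeToReserve` to a day, and the second can let it through by returning zero on the next poll.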
public void testRemovesRateLimitGroup_AfterStaleDuration() {
var now = Instant.now();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
var service = new RequestExecutorService(
threadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
clock,
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
assertThat(service.numberOfRateLimitGroups(), is(1));
// the time is moved to after the stale duration, so now we should remove this grouping
when(clock.instant()).thenReturn(now.plus(Duration.ofDays(2)));
service.removeStaleGroupings();
assertThat(service.numberOfRateLimitGroups(), is(0));
var requestManager2 = RequestManagerTests.createMock(requestSender, "id2");
service.execute(requestManager2, new DocumentsOnlyInput(List.of()), null, listener);
assertThat(service.numberOfRateLimitGroups(), is(1));
}
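The stale-grouping test above advances a mocked `Clock` past the configured stale duration and expects the grouping to be dropped. Sketched under stated assumptions (the class and method bodies below are hypothetical, standing in for the per-grouping bookkeeping inside `RequestExecutorService`; a `Supplier<Instant>` stands in for `java.time.Clock` to keep the example easy to fake), the cleanup logic amounts to:

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Illustrative registry: each rate-limit grouping remembers when it was last used,
// and a periodic cleanup drops groupings idle for longer than the stale duration.
class RateLimitGroupsSketch {
    private final Map<String, Instant> lastUsed = new ConcurrentHashMap<>();
    private final Duration staleDuration;
    private final Supplier<Instant> clock;

    RateLimitGroupsSketch(Duration staleDuration, Supplier<Instant> clock) {
        this.staleDuration = staleDuration;
        this.clock = clock;
    }

    // Record that a grouping was just used (e.g. when a request is enqueued).
    void touch(String grouping) {
        lastUsed.put(grouping, clock.get());
    }

    // Drop every grouping whose last use is older than the stale duration.
    void removeStaleGroupings() {
        Instant cutoff = clock.get().minus(staleDuration);
        lastUsed.entrySet().removeIf(e -> e.getValue().isBefore(cutoff));
    }

    int numberOfRateLimitGroups() {
        return lastUsed.size();
    }

    public static void main(String[] args) {
        Instant[] now = { Instant.parse("2024-01-01T00:00:00Z") };
        var groups = new RateLimitGroupsSketch(Duration.ofDays(1), () -> now[0]);
        groups.touch("id1");
        assert groups.numberOfRateLimitGroups() == 1;
        now[0] = now[0].plus(Duration.ofDays(2)); // move past the stale duration
        groups.removeStaleGroupings();
        assert groups.numberOfRateLimitGroups() == 0;
        System.out.println("ok");
    }
}
```

The `testStartsCleanupThread` test then only needs to verify that this cleanup is scheduled at the configured fixed delay.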
public void testStartsCleanupThread() {
var mockThreadPool = mock(ThreadPool.class);
when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
var service = new RequestExecutorService(
mockThreadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
service.shutdown();
service.start();
ArgumentCaptor<TimeValue> argument = ArgumentCaptor.forClass(TimeValue.class);
verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), argument.capture(), any());
assertThat(argument.getValue(), is(TimeValue.timeValueDays(1)));
}
private Future<?> submitShutdownRequest(
@@ -552,12 +689,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
}
private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) {
return new RequestExecutorService(
"test_service",
threadPool,
startupLatch,
createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(requestSender)
);
return new RequestExecutorService(threadPool, startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender);
}
}


@@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.RequestTests;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
@@ -21,34 +22,47 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ExecutableRequestCreatorTests {
public class RequestManagerTests {
public static RequestManager createMock() {
var mockCreator = mock(RequestManager.class);
when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {});
return createMock(mock(RequestSender.class));
}
return mockCreator;
public static RequestManager createMock(String inferenceEntityId) {
return createMock(mock(RequestSender.class), inferenceEntityId);
}
public static RequestManager createMock(RequestSender requestSender) {
return createMock(requestSender, "id");
return createMock(requestSender, "id", new RateLimitSettings(1));
}
public static RequestManager createMock(RequestSender requestSender, String modelId) {
var mockCreator = mock(RequestManager.class);
public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) {
return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1));
}
public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) {
var mockManager = mock(RequestManager.class);
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[5];
return (Runnable) () -> requestSender.send(
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[4];
requestSender.send(
mock(Logger.class),
RequestTests.mockRequest(modelId),
RequestTests.mockRequest(inferenceEntityId),
HttpClientContext.create(),
() -> false,
mock(ResponseHandler.class),
listener
);
}).when(mockCreator).create(any(), anyList(), any(), any(), any(), any());
return mockCreator;
return Void.TYPE;
}).when(mockManager).execute(any(), anyList(), any(), any(), any());
// just return something consistent so the hashing works
when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId);
when(mockManager.rateLimitSettings()).thenReturn(settings);
when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId);
return mockManager;
}
}


@@ -1,27 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
public class SingleRequestManagerTests extends ESTestCase {
public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() {
var requestCreator = mock(RequestManager.class);
var request = mock(InferenceRequest.class);
when(request.getRequestCreator()).thenReturn(requestCreator);
new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), HttpClientContext.create());
verifyNoInteractions(requestCreator);
}
}


@@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -59,7 +58,7 @@ public class SenderServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
@@ -67,7 +66,7 @@ public class SenderServiceTests extends ESTestCase {
listener.actionGet(TIMEOUT);
verify(sender, times(1)).start();
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
}
verify(sender, times(1)).close();
@@ -79,7 +78,7 @@ public class SenderServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
@@ -89,7 +88,7 @@ public class SenderServiceTests extends ESTestCase {
service.start(mock(Model.class), listener);
listener.actionGet(TIMEOUT);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(2)).start();
}


@@ -76,7 +76,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -819,7 +818,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@@ -841,7 +840,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}


@@ -112,7 +112,8 @@ public class AzureAiStudioChatCompletionServiceSettingsTests extends ESTestCase
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, CoreMatchers.is("""
{"target":"target_value","provider":"openai","endpoint_type":"token"}"""));
{"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
"rate_limit":{"requests_per_minute":3}}"""));
}
public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {


@@ -295,7 +295,7 @@ public class AzureAiStudioEmbeddingsServiceSettingsTests extends ESTestCase {
assertThat(xContentResult, CoreMatchers.is("""
{"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
"dimensions":1024,"max_input_tokens":512}"""));
"rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}"""));
}
public static HashMap<String, Object> createRequestSettingsMap(


@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -594,7 +593,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@@ -616,7 +615,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}


@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
import java.io.IOException;
@@ -46,7 +47,8 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria
AzureOpenAiServiceFields.API_VERSION,
apiVersion
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null)));
@@ -63,18 +65,6 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria
{"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}"""));
}
public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = entity.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}"""));
}
@Override
protected Writeable.Reader<AzureOpenAiCompletionServiceSettings> instanceReader() {
return AzureOpenAiCompletionServiceSettings::new;


@@ -389,7 +389,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria
"dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}"""));
}
public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_RateLimit() throws IOException {
public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser() throws IOException {
var entity = new AzureOpenAiEmbeddingsServiceSettings(
"resource",
"deployment",
@@ -408,7 +408,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria
assertThat(xContentResult, is("""
{"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
"dimensions":1024,"max_input_tokens":512}"""));
"dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":1}}"""));
}
@Override


@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -613,7 +612,7 @@ public class CohereServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@@ -635,7 +634,7 @@ public class CohereServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}


@@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.HashMap;
@@ -28,7 +29,8 @@ public class CohereCompletionModelTests extends ESTestCase {
"service",
new HashMap<>(Map.of()),
new HashMap<>(Map.of("model", "overridden model")),
null
null,
ConfigurationParseContext.PERSISTENT
);
assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));


@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -34,7 +35,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
var model = "model";
var serviceSettings = CohereCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model))
new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null)));
@@ -55,7 +57,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute))));
@@ -72,18 +75,6 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
{"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}"""));
}
public void testToXContent_WithFilteredObject_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"url":"url","model_id":"model"}"""));
}
@Override
protected Writeable.Reader<CohereCompletionServiceSettings> instanceReader() {
return CohereCompletionServiceSettings::new;


@@ -331,21 +331,6 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
"rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}"""));
}
public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new CohereEmbeddingsServiceSettings(
new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)),
CohereEmbeddingType.INT8
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """
"embedding_type":"byte"}"""));
}
@Override
protected Writeable.Reader<CohereEmbeddingsServiceSettings> instanceReader() {
return CohereEmbeddingsServiceSettings::new;


@@ -51,20 +51,6 @@ public class CohereRerankServiceSettingsTests extends AbstractWireSerializingTes
"rate_limit":{"requests_per_minute":3}}"""));
}
public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new CohereRerankServiceSettings(
new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3))
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
// TODO we probably shouldn't allow configuring these fields for reranking
assertThat(xContentResult, is("""
{"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id"}"""));
}
@Override
protected Writeable.Reader<CohereRerankServiceSettings> instanceReader() {
return CohereRerankServiceSettings::new;


@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -494,7 +493,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@@ -516,7 +515,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}


@@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.net.URISyntaxException;
@@ -28,7 +29,8 @@ public class GoogleAiStudioCompletionModelTests extends ESTestCase {
"service",
new HashMap<>(Map.of("model_id", "model")),
new HashMap<>(Map.of()),
null
null,
ConfigurationParseContext.PERSISTENT
);
assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));


@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -31,7 +32,10 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
public void testFromMap_Request_CreatesSettingsCorrectly() {
var model = "some model";
var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)));
var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null)));
}
@@ -47,18 +51,6 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
{"model_id":"model","rate_limit":{"requests_per_minute":360}}"""));
}
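The expectation above ({"model_id":"model","rate_limit":{"requests_per_minute":360}}), together with the removal of the `testToFilteredXContent_WritesAllValues_Except_RateLimit` tests below, suggests that `rate_limit` is now kept in the filtered, user-visible view of the service settings while secrets are still dropped. A hedged sketch of that filtering idea (the field names and the `api_key` secret are illustrative, not the actual XContent machinery):

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

public class FilteredSettingsSketch {
    // Keys that never appear in API responses; illustrative only.
    private static final Set<String> HIDDEN = Set.of("api_key");

    // Returns the user-visible view of the settings: secrets are dropped,
    // while rate_limit (per the updated expectations above) is kept.
    static Map<String, Object> filtered(Map<String, Object> settings) {
        Map<String, Object> out = new LinkedHashMap<>(settings);
        out.keySet().removeAll(HIDDEN);
        return out;
    }

    public static void main(String[] args) {
        Map<String, Object> settings = new LinkedHashMap<>();
        settings.put("model_id", "model");
        settings.put("api_key", "secret");
        settings.put("rate_limit", Map.of("requests_per_minute", 360));
        System.out.println(filtered(settings).keySet()); // [model_id, rate_limit]
    }
}
```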
public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
var entity = new GoogleAiStudioCompletionServiceSettings("model", null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = entity.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"model_id":"model"}"""));
}
@Override
protected Writeable.Reader<GoogleAiStudioCompletionServiceSettings> instanceReader() {
return GoogleAiStudioCompletionServiceSettings::new;

View file

@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@ -55,7 +56,8 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
ServiceFields.SIMILARITY,
similarity.toString()
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null)));
@ -80,23 +82,6 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
}"""));
}
public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = entity.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model_id":"model",
"max_input_tokens": 1024,
"dimensions": 8,
"similarity": "dot_product"
}"""));
}
@Override
protected Writeable.Reader<GoogleAiStudioEmbeddingsServiceSettings> instanceReader() {
return GoogleAiStudioEmbeddingsServiceSettings::new;

View file

@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.junit.After;
import org.junit.Before;
@ -33,7 +34,6 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -59,7 +59,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@ -81,7 +81,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}
@ -111,7 +111,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> secretSettings,
String failureMessage
String failureMessage,
ConfigurationParseContext context
) {
return null;
}
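Throughout these tests, `fromMap` now takes a `ConfigurationParseContext` (here always `PERSISTENT`). One plausible reading of that pattern, sketched below with illustrative names, defaults, and validation rules (the 360 requests-per-minute default mirrors the test expectations above but is otherwise an assumption): a user-supplied `REQUEST` can be validated strictly, while `PERSISTENT` configurations read back from storage are parsed leniently.

```java
import java.util.Map;

public class ParseContextSketch {
    // Distinguishes parsing a live API request from re-reading stored config.
    enum ConfigurationParseContext { REQUEST, PERSISTENT }

    record Settings(String modelId, int requestsPerMinute) {}

    static final int DEFAULT_RPM = 360; // illustrative default

    static Settings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
        Object rpm = map.get("requests_per_minute");
        if (rpm == null) {
            return new Settings((String) map.get("model_id"), DEFAULT_RPM);
        }
        // Hypothetical rule: only reject bad values on a fresh user request,
        // so previously persisted configs keep loading.
        if (context == ConfigurationParseContext.REQUEST && ((Integer) rpm) <= 0) {
            throw new IllegalArgumentException("requests_per_minute must be positive");
        }
        return new Settings((String) map.get("model_id"), (Integer) rpm);
    }

    public static void main(String[] args) {
        Settings s = fromMap(Map.of("model_id", "model"), ConfigurationParseContext.PERSISTENT);
        System.out.println(s.requestsPerMinute()); // 360
    }
}
```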

View file

@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -57,7 +58,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
var dims = 384;
var maxInputTokens = 128;
{
var serviceSettings = HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)));
var serviceSettings = HuggingFaceServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, url)),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new HuggingFaceServiceSettings(url)));
}
{
@ -73,7 +77,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
serviceSettings,
@ -95,7 +100,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
serviceSettings,
@ -105,7 +111,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
}
public void testFromMap_MissingUrl_ThrowsError() {
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceServiceSettings.fromMap(new HashMap<>()));
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
);
assertThat(
thrownException.getMessage(),
@ -118,7 +127,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
public void testFromMap_EmptyUrl_ThrowsError() {
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")))
() -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT)
);
assertThat(
@ -136,7 +145,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
var url = "https://www.abc^.com";
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)))
() -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT)
);
assertThat(
@ -152,7 +161,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
var similarity = "by_size";
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)))
() -> HuggingFaceServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
@ -175,18 +187,6 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
{"url":"url","rate_limit":{"requests_per_minute":3}}"""));
}
public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new HuggingFaceServiceSettings(ServiceUtils.createUri("url"), null, null, null, new RateLimitSettings(3));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = org.elasticsearch.common.Strings.toString(builder);
assertThat(xContentResult, is("""
{"url":"url"}"""));
}
@Override
protected Writeable.Reader<HuggingFaceServiceSettings> instanceReader() {
return HuggingFaceServiceSettings::new;

View file

@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -32,7 +33,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
public void testFromMap() {
var url = "https://www.abc.com";
var serviceSettings = HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)));
var serviceSettings = HuggingFaceElserServiceSettings.fromMap(
new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
ConfigurationParseContext.PERSISTENT
);
assertThat(new HuggingFaceElserServiceSettings(url), is(serviceSettings));
}
@ -40,7 +44,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
public void testFromMap_EmptyUrl_ThrowsError() {
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")))
() -> HuggingFaceElserServiceSettings.fromMap(
new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
@ -55,7 +62,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
}
public void testFromMap_MissingUrl_ThrowsError() {
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>()));
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
);
assertThat(
thrownException.getMessage(),
@ -72,7 +82,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
var url = "https://www.abc^.com";
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)))
() -> HuggingFaceElserServiceSettings.fromMap(
new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
@ -98,18 +111,6 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
{"url":"url","max_input_tokens":512,"rate_limit":{"requests_per_minute":3}}"""));
}
public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new HuggingFaceElserServiceSettings(ServiceUtils.createUri("url"), new RateLimitSettings(3));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = org.elasticsearch.common.Strings.toString(builder);
assertThat(xContentResult, is("""
{"url":"url","max_input_tokens":512}"""));
}
@Override
protected Writeable.Reader<HuggingFaceElserServiceSettings> instanceReader() {
return HuggingFaceElserServiceSettings::new;

View file

@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -393,7 +392,7 @@ public class MistralServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@ -415,7 +414,7 @@ public class MistralServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}
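The commit message describes this PR as a "rate limit queuing logic refactor" that sleeps only the least amount necessary. A loose sketch of that queue-and-drain idea, under the assumption that `requests_per_minute` is converted into a minimum gap between sends (class and method names are illustrative, not the actual Elasticsearch executor):

```java
import java.util.ArrayDeque;
import java.util.Queue;

public class QueuedExecutorSketch {
    private final Queue<Runnable> queue = new ArrayDeque<>();
    private final long intervalMillis;
    private long nextSendMillis;

    QueuedExecutorSketch(int requestsPerMinute) {
        // requests_per_minute -> minimum milliseconds between requests
        this.intervalMillis = 60_000L / requestsPerMinute;
    }

    void enqueue(Runnable request) {
        queue.add(request);
    }

    // Drains the queue, sleeping only the time remaining until the next slot.
    void drain() {
        while (!queue.isEmpty()) {
            long now = System.currentTimeMillis();
            if (now < nextSendMillis) {
                try {
                    Thread.sleep(nextSendMillis - now);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
            nextSendMillis = Math.max(now, nextSendMillis) + intervalMillis;
            queue.poll().run();
        }
    }

    public static void main(String[] args) {
        QueuedExecutorSketch executor = new QueuedExecutorSketch(6_000); // 100 req/s
        int[] sent = { 0 };
        for (int i = 0; i < 3; i++) {
            executor.enqueue(() -> sent[0]++);
        }
        executor.drain();
        System.out.println(sent[0]); // 3
    }
}
```

Tracking `nextSendMillis` rather than sleeping a fixed interval after every request means a caller that arrives late waits nothing, which matches the "only sleeping least amount" note in the commit history.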

View file

@ -98,18 +98,6 @@ public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
"rate_limit":{"requests_per_minute":3}}"""));
}
public void testToFilteredXContent_WritesFilteredValues() throws IOException {
var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = entity.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, CoreMatchers.is("""
{"model":"model_name","dimensions":1024,"max_input_tokens":512}"""));
}
public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
var outputBuffer = new BytesStreamOutput();
var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));

View file

@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -675,7 +674,7 @@ public class OpenAiServiceTests extends ESTestCase {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name");
@ -697,7 +696,7 @@ public class OpenAiServiceTests extends ESTestCase {
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender(anyString());
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}

View file

@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
@ -48,7 +49,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
@ -77,7 +79,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit))
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
@ -101,7 +104,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens
)
)
),
ConfigurationParseContext.PERSISTENT
);
assertNull(serviceSettings.uri());
@ -113,7 +117,10 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
public void testFromMap_EmptyUrl_ThrowsError() {
var thrownException = expectThrows(
ValidationException.class,
() -> OpenAiChatCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")))
() -> OpenAiChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
@ -132,7 +139,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
var maxInputTokens = 8192;
var serviceSettings = OpenAiChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens))
new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)),
ConfigurationParseContext.PERSISTENT
);
assertNull(serviceSettings.uri());
@ -144,7 +152,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
var thrownException = expectThrows(
ValidationException.class,
() -> OpenAiChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model"))
new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")),
ConfigurationParseContext.PERSISTENT
)
);
@ -164,7 +173,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
var thrownException = expectThrows(
ValidationException.class,
() -> OpenAiChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model"))
new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")),
ConfigurationParseContext.PERSISTENT
)
);
@ -213,19 +223,6 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
{"model_id":"model","rate_limit":{"requests_per_minute":500}}"""));
}
public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new OpenAiChatCompletionServiceSettings("model", "url", "org", 1024, new RateLimitSettings(2));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = org.elasticsearch.common.Strings.toString(builder);
assertThat(xContentResult, is("""
{"model_id":"model","url":"url","organization_id":"org",""" + """
"max_input_tokens":1024}"""));
}
@Override
protected Writeable.Reader<OpenAiChatCompletionServiceSettings> instanceReader() {
return OpenAiChatCompletionServiceSettings::new;

Some files were not shown because too many files have changed in this diff.