[ML] Inference API rate limit queuing logic refactor (#107706)

* Adding new executor

* Adding in queuing logic

* working tests

* Added cleanup task

* Update docs/changelog/107706.yaml

* Updating yml

* deregistering callbacks for settings changes

* Cleaning up code

* Update docs/changelog/107706.yaml

* Fixing rate limit settings bug and only sleeping least amount

* Removing debug logging

* Removing commented code

* Renaming feedback

* fixing tests

* Updating docs and validation

* Fixing source blocks

* Adjusting cancel logic

* Reformatting ascii

* Addressing feedback

* adding rate limiting for google embeddings and mistral

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Jonathan Buttner 2024-06-05 08:25:25 -04:00 committed by GitHub
parent cd84749d87
commit fdb5058b13
102 changed files with 1499 additions and 937 deletions
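The queuing behavior this PR describes ("Adding in queuing logic", "only sleeping least amount") can be sketched, purely illustratively, as a per-endpoint requests-per-minute limiter whose shared executor sleeps only for the smallest pending delay. All names below are hypothetical and are not the actual Elasticsearch implementation:

```java
// Illustrative sketch only (hypothetical names, not the actual Elasticsearch code):
// each endpoint gets a limiter derived from its `requests_per_minute` setting, and
// the shared executor sleeps only for the least of the pending delays.
public final class RateLimitSketch {
    private final long intervalNanos; // minimum spacing between two requests
    private long nextAllowedNanos;    // earliest time the next request may run

    public RateLimitSketch(long requestsPerMinute, long nowNanos) {
        this.intervalNanos = 60_000_000_000L / requestsPerMinute;
        this.nextAllowedNanos = nowNanos;
    }

    /** Reserve the next slot and return how long the caller must wait, in nanos. */
    public long reserveDelayNanos(long nowNanos) {
        long delay = Math.max(0L, nextAllowedNanos - nowNanos);
        nextAllowedNanos = Math.max(nowNanos, nextAllowedNanos) + intervalNanos;
        return delay;
    }

    /** The executor sleeps for the smallest delay across all queued limiters. */
    public static long leastSleepNanos(long delay1, long delay2) {
        return Math.min(delay1, delay2);
    }

    public static void main(String[] args) {
        RateLimitSketch limiter = new RateLimitSketch(60, 0L); // 60 rpm -> 1s spacing
        if (limiter.reserveDelayNanos(0L) != 0L) throw new AssertionError();
        if (limiter.reserveDelayNanos(0L) != 1_000_000_000L) throw new AssertionError();
        if (leastSleepNanos(5L, 3L) != 3L) throw new AssertionError();
        System.out.println("ok");
    }
}
```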


@@ -0,0 +1,5 @@
pr: 107706
summary: Add rate limiting support for the Inference API
area: Machine Learning
type: enhancement
issues: []


@@ -7,21 +7,17 @@ experimental[]
Creates an {infer} endpoint to perform an {infer} task.
IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face.
For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models.
However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the
<<ml-df-trained-models-apis>>.
[discrete]
[[put-inference-api-request]]
==== {api-request-title}
`PUT /_inference/<task_type>/<inference_id>`
[discrete]
[[put-inference-api-prereqs]]
==== {api-prereq-title}
@@ -29,7 +25,6 @@ use these models or if you want to use non-NLP models, use the
* Requires the `manage_inference` <<privileges-list-cluster,cluster privilege>>
(the built-in `inference_admin` role grants this privilege)
[discrete]
[[put-inference-api-desc]]
==== {api-description-title}
@@ -48,25 +43,23 @@ The following services are available through the {infer} API:
* Hugging Face
* OpenAI
[discrete]
[[put-inference-api-path-params]]
==== {api-path-parms-title}
`<inference_id>`::
(Required, string)
The unique identifier of the {infer} endpoint.
`<task_type>`::
(Required, string)
The type of the {infer} task that the model will perform.
Available task types:
* `completion`,
* `rerank`,
* `sparse_embedding`,
* `text_embedding`.
[discrete]
[[put-inference-api-request-body]]
==== {api-request-body-title}
@@ -78,21 +71,18 @@ Available services:
* `azureopenai`: specify the `completion` or `text_embedding` task type to use the Azure OpenAI service.
* `azureaistudio`: specify the `completion` or `text_embedding` task type to use the Azure AI Studio service.
* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service.
* `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland.
* `elser`: specify the `sparse_embedding` task type to use the ELSER service.
* `googleaistudio`: specify the `completion` task type to use the Google AI Studio service.
* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service.
* `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service.
`service_settings`::
(Required, object)
Settings used to install the {infer} model.
These settings are specific to the
`service` you specified.
+
.`service_settings` for the `azureaistudio` service
@@ -104,11 +94,10 @@ Settings used to install the {infer} model. These settings are specific to the
A valid API key of your Azure AI Studio model deployment.
This key can be found on the overview page for your deployment in the management section of your https://ai.azure.com/[Azure AI Studio] account.
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`target`:::
(Required, string)
@@ -142,11 +131,13 @@ For "real-time" endpoints which are billed per hour of usage, specify `realtime`
By default, the `azureaistudio` service sets the number of requests allowed per minute to `240`.
This helps to minimize the number of rate limit errors returned from Azure AI Studio.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
=====
+
.`service_settings` for the `azureopenai` service
@@ -181,6 +172,22 @@ Your Azure OpenAI deployments can be found though the https://oai.azure.com/[Azu
The Azure API version ID to use.
We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings[latest supported non-preview version].
`rate_limit`:::
(Optional, object)
The `azureopenai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `1440`.
For `completion` it is set to `120`.
This helps to minimize the number of rate limit errors returned from Azure.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for Azure can be found in the https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits[Quota limits docs] and https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest[How to change the quotas].
=====
+
.`service_settings` for the `cohere` service
@@ -188,24 +195,24 @@ We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/opena
=====
`api_key`:::
(Required, string)
A valid API key of your Cohere account.
You can find your Cohere API keys or you can create a new one
https://dashboard.cohere.com/api-keys[on the API keys settings page].
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`embedding_type`::
(Optional, string)
Only for `text_embedding`.
Specifies the types of embeddings you want to get back.
Defaults to `float`.
Valid values are:
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
* `float`: use it for the default float embeddings.
* `int8`: use it for signed int8 embeddings.
`model_id`::
(Optional, string)
@@ -214,50 +221,68 @@ To review the available `rerank` models, refer to the
https://docs.cohere.com/reference/rerank-1[Cohere docs].
To review the available `text_embedding` models, refer to the
https://docs.cohere.com/reference/embed[Cohere docs].
The default value for
`text_embedding` is `embed-english-v2.0`.
`rate_limit`:::
(Optional, object)
By default, the `cohere` service sets the number of requests allowed per minute to `10000`.
This value is the same for all task types.
This helps to minimize the number of rate limit errors returned from Cohere.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
+
More information about Cohere's rate limits can be found in https://docs.cohere.com/docs/going-live#production-key-specifications[Cohere's production key docs].
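Assembled as a full request, the Cohere rate limit override above might look like the following; the endpoint name, `model_id`, and the `100` value are illustrative placeholders modeled on the examples later in this page, not part of this change:

```
PUT _inference/text_embedding/cohere-embeddings
{
    "service": "cohere",
    "service_settings": {
        "api_key": "<api_key>",
        "model_id": "embed-english-v3.0",
        "rate_limit": {
            "requests_per_minute": 100
        }
    }
}
```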
=====
+
.`service_settings` for the `elasticsearch` service
[%collapsible%closed]
=====
`model_id`:::
(Required, string)
The name of the model to use for the {infer} task.
It can be the ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or a text embedding model already
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].
`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.
`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.
=====
+
.`service_settings` for the `elser` service
[%collapsible%closed]
=====
`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.
`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.
=====
+
.`service_settings` for the `googleaistudio` service
[%collapsible%closed]
=====
`api_key`:::
(Required, string)
A valid API key for the Google Gemini API.
@@ -274,76 +299,113 @@ This helps to minimize the number of rate limit errors returned from Google AI S
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
--
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
--
=====
+
.`service_settings` for the `hugging_face` service
[%collapsible%closed]
=====
`api_key`:::
(Required, string)
A valid access token of your Hugging Face account.
You can find your Hugging Face access tokens or you can create a new one
https://huggingface.co/settings/tokens[on the settings page].
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`url`:::
(Required, string)
The URL endpoint to use for the requests.
`rate_limit`:::
(Optional, object)
By default, the `huggingface` service sets the number of requests allowed per minute to `3000`.
This helps to minimize the number of rate limit errors returned from Hugging Face.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
=====
+
.`service_settings` for the `openai` service
[%collapsible%closed]
=====
`api_key`:::
(Required, string)
A valid API key of your OpenAI account.
You can find your OpenAI API keys in your OpenAI account under the
https://platform.openai.com/api-keys[API keys section].
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.
`model_id`:::
(Required, string)
The name of the model to use for the {infer} task.
Refer to the
https://platform.openai.com/docs/guides/embeddings/what-are-embeddings[OpenAI documentation]
for the list of available text embedding models.
`organization_id`:::
(Optional, string)
The unique identifier of your organization.
You can find the Organization ID in your OpenAI account under
https://platform.openai.com/account/organization[**Settings** > **Organizations**].
`url`:::
(Optional, string)
The URL endpoint to use for the requests.
Can be changed for testing purposes.
Defaults to `https://api.openai.com/v1/embeddings`.
`rate_limit`:::
(Optional, object)
The `openai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `3000`.
For `completion` it is set to `500`.
This helps to minimize the number of rate limit errors returned from OpenAI.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
"requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for OpenAI can be found in your https://platform.openai.com/account/limits[Account limits].
=====
`task_settings`::
(Optional, object)
Settings to configure the {infer} task.
These settings are specific to the
`<task_type>` you specified.
+
.`task_settings` for the `completion` task type
[%collapsible%closed]
=====
`do_sample`:::
(Optional, float)
For the `azureaistudio` service only.
@@ -358,8 +420,8 @@ Defaults to 64.
`user`:::
(Optional, string)
For `openai` service only.
Specifies the user issuing the request, which can be used for abuse detection.
`temperature`:::
(Optional, float)
@@ -378,45 +440,46 @@ Should not be used if `temperature` is specified.
.`task_settings` for the `rerank` task type
[%collapsible%closed]
=====
`return_documents`::
(Optional, boolean)
For `cohere` service only.
Specify whether to return doc text within the results.
`top_n`::
(Optional, integer)
The number of most relevant documents to return, defaults to the number of the documents.
=====
+
.`task_settings` for the `text_embedding` task type
[%collapsible%closed]
=====
`input_type`:::
(Optional, string)
For `cohere` service only.
Specifies the type of input passed to the model.
Valid values are:
* `classification`: use it for embeddings passed through a text classifier.
* `clustering`: use it for the embeddings run through a clustering algorithm.
* `ingest`: use it for storing document embeddings in a vector database.
* `search`: use it for storing embeddings of search queries run against a vector database to find relevant documents.
`truncate`:::
(Optional, string)
For `cohere` service only.
Specifies how the API handles inputs longer than the maximum token length.
Defaults to `END`.
Valid values are:
* `NONE`: when the input exceeds the maximum input token length an error is returned.
* `START`: when the input exceeds the maximum input token length the start of the input is discarded.
* `END`: when the input exceeds the maximum input token length the end of the input is discarded.
`user`:::
(Optional, string)
For `openai`, `azureopenai` and `azureaistudio` services only.
Specifies the user issuing the request, which can be used for abuse detection.
=====
[discrete]
@@ -470,7 +533,6 @@ PUT _inference/completion/azure_ai_studio_completion
The list of chat completion models that you can choose from in your deployment can be found in the https://ai.azure.com/explore/models?selectedTask=chat-completion[Azure AI Studio model explorer].
[discrete]
[[inference-example-azureopenai]]
===== Azure OpenAI service
@@ -519,7 +581,6 @@ The list of chat completion models that you can choose from in your Azure OpenAI
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-models[GPT-4 and GPT-4 Turbo models]
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35[GPT-3.5]
[discrete]
[[inference-example-cohere]]
===== Cohere service
@@ -565,7 +626,6 @@ PUT _inference/rerank/cohere-rerank
For more examples, also review the
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].
[discrete]
[[inference-example-e5]]
===== E5 via the `elasticsearch` service
@@ -586,10 +646,9 @@ PUT _inference/text_embedding/my-e5-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of one of the built-in E5 models.
Valid values are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`.
For further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].
[discrete]
[[inference-example-elser]]
@@ -597,8 +656,7 @@ further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].
The following example shows how to create an {infer} endpoint called
`my-elser-model` to perform a `sparse_embedding` task type.
Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more info.
[source,console]
------------------------------------------------------------
@@ -672,16 +730,17 @@ PUT _inference/text_embedding/hugging-face-embeddings
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> A valid Hugging Face access token.
You can find it on the
https://huggingface.co/settings/tokens[settings page of your account].
<2> The {infer} endpoint URL you created on Hugging Face.
Create a new {infer} endpoint on
https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an endpoint URL.
Select the model you want to use on the new endpoint creation page - for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings`
task under the Advanced configuration section.
Create the endpoint.
Copy the URL after the endpoint initialization has been finished.
[discrete]
[[inference-example-hugging-face-supported-models]]
@@ -695,7 +754,6 @@ The list of recommended models for the Hugging Face service:
* https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base]
* https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small]
[discrete]
[[inference-example-eland]]
===== Models uploaded by Eland via the elasticsearch service
@@ -716,11 +774,9 @@ PUT _inference/text_embedding/my-msmarco-minilm-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of a text embedding model which has already been
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].
[discrete]
[[inference-example-openai]]
===== OpenAI service
@@ -756,4 +812,3 @@ PUT _inference/completion/openai-completion
}
------------------------------------------------------------
// TEST[skip:TBD]


@@ -88,6 +88,13 @@ public class TimeValue implements Comparable<TimeValue> {
return new TimeValue(days, TimeUnit.DAYS);
}
/**
* @return the {@link TimeValue} object that has the least duration.
*/
public static TimeValue min(TimeValue time1, TimeValue time2) {
return time1.compareTo(time2) < 0 ? time1 : time2;
}
/** /**
* @return the unit used for the this time value, see {@link #duration()} * @return the unit used for the this time value, see {@link #duration()}
*/ */
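The new `TimeValue.min` helper above picks the shorter of two durations; the PR uses it to sleep only the least required amount while waiting on a rate limit. A minimal sketch of that pattern, using `java.time.Duration` as a stand-in for Elasticsearch's `TimeValue` (the variable names are illustrative, not from the PR):

```java
import java.time.Duration;

// Sketch of the "sleep only the least amount" pattern. Duration.compareTo
// gives the same min semantics as the new TimeValue.min helper.
public class MinSleepSketch {
    static Duration min(Duration a, Duration b) {
        return a.compareTo(b) < 0 ? a : b;
    }

    public static void main(String[] args) {
        Duration untilNextToken = Duration.ofMillis(2500); // how long the rate limiter wants to wait
        Duration pollInterval = Duration.ofSeconds(1);     // upper bound per wait, so shutdown stays responsive

        // Sleep for the smaller of the two, then re-check shutdown/queue state.
        Duration sleep = min(untilNextToken, pollInterval);
        System.out.println(sleep); // PT1S
    }
}
```

Capping each individual sleep this way keeps the worker thread responsive to shutdown and settings changes even when the rate limiter asks for a long wait.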

View file

@ -17,6 +17,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.object.HasToString.hasToString; import static org.hamcrest.object.HasToString.hasToString;
@ -231,6 +232,12 @@ public class TimeValueTests extends ESTestCase {
assertThat(ex.getMessage(), containsString("duration cannot be negative")); assertThat(ex.getMessage(), containsString("duration cannot be negative"));
} }
public void testMin() {
assertThat(TimeValue.min(TimeValue.ZERO, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(0)));
assertThat(TimeValue.min(TimeValue.MAX_VALUE, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(1)));
assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE));
}
private TimeUnit randomTimeUnitObject() { private TimeUnit randomTimeUnitObject() {
return randomFrom( return randomFrom(
TimeUnit.NANOSECONDS, TimeUnit.NANOSECONDS,

View file

@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor {
private final ServiceComponents serviceComponents; private final ServiceComponents serviceComponents;
public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
// TODO Batching - accept a class that can handle batching
this.sender = Objects.requireNonNull(sender); this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents); this.serviceComponents = Objects.requireNonNull(serviceComponents);
} }

View file

@ -36,6 +36,7 @@ public class CohereEmbeddingsAction implements ExecutableAction {
model.getServiceSettings().getCommonSettings().uri(), model.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings" "Cohere embeddings"
); );
// TODO Batching - pass the batching class on to the CohereEmbeddingsRequestManager
requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool);
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -37,17 +36,16 @@ public class AzureAiStudioChatCompletionRequestManager extends AzureAiStudioRequ
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input); AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
private static ResponseHandler createCompletionHandler() { private static ResponseHandler createCompletionHandler() {

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -41,17 +40,16 @@ public class AzureAiStudioEmbeddingsRequestManager extends AzureAiStudioRequestM
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model); AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
private static ResponseHandler createEmbeddingsHandler() { private static ResponseHandler createEmbeddingsHandler() {

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -43,16 +42,15 @@ public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManag
} }
@Override @Override
public Runnable create( public void execute(
@Nullable String query, @Nullable String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model); AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -55,16 +54,15 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model); AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -38,11 +38,16 @@ abstract class BaseRequestManager implements RequestManager {
@Override @Override
public Object rateLimitGrouping() { public Object rateLimitGrouping() {
return rateLimitGroup; // If two inference endpoints have the same information defining the group but different
// rate limits, they should be in different groups; otherwise whoever initially created the group would set
// the rate and the other inference endpoint's rate limit would be ignored
return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
} }
@Override @Override
public RateLimitSettings rateLimitSettings() { public RateLimitSettings rateLimitSettings() {
return rateLimitSettings; return rateLimitSettings;
} }
private record EndpointGrouping(Object group, RateLimitSettings settings) {}
} }
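The `EndpointGrouping` record above makes the rate limit settings part of the grouping key. Because records get value-based `equals`/`hashCode`, two endpoints sharing the same group-defining information but configured with different rate limits hash to different groups, so neither endpoint's configured rate is silently ignored. A minimal illustration with hypothetical stand-in records (not the actual Elasticsearch classes):

```java
// Illustrative stand-ins showing why the grouping key must include the
// rate limit settings. Record equals/hashCode compare all components.
record RateLimit(long requestsPerTimeUnit) {}
record EndpointGrouping(Object group, RateLimit settings) {}

public class GroupingSketch {
    public static void main(String[] args) {
        Object sharedAccountKey = "api-key-hash"; // same account info for both endpoints

        // Same account, different rate limits: the settings component keeps
        // the two endpoints in separate rate limit groups.
        var groupA = new EndpointGrouping(sharedAccountKey, new RateLimit(100));
        var groupB = new EndpointGrouping(sharedAccountKey, new RateLimit(500));

        System.out.println(groupA.equals(groupB)); // false -> separate groups
        System.out.println(groupA.equals(new EndpointGrouping(sharedAccountKey, new RateLimit(100)))); // true
    }
}
```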

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -46,16 +45,15 @@ public class CohereCompletionRequestManager extends CohereRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
CohereCompletionRequest request = new CohereCompletionRequest(input, model); CohereCompletionRequest request = new CohereCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -44,16 +43,15 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model); CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -44,16 +43,15 @@ public class CohereRerankRequestManager extends CohereRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
CohereRerankRequest request = new CohereRerankRequest(query, input, model); CohereRerankRequest request = new CohereRerankRequest(query, input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -23,7 +23,6 @@ record ExecutableInferenceRequest(
RequestSender requestSender, RequestSender requestSender,
Logger logger, Logger logger,
Request request, Request request,
HttpClientContext context,
ResponseHandler responseHandler, ResponseHandler responseHandler,
Supplier<Boolean> hasFinished, Supplier<Boolean> hasFinished,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
@ -34,7 +33,7 @@ record ExecutableInferenceRequest(
var inferenceEntityId = request.createHttpRequest().inferenceEntityId(); var inferenceEntityId = request.createHttpRequest().inferenceEntityId();
try { try {
requestSender.send(logger, request, context, hasFinished, responseHandler, listener); requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
} catch (Exception e) { } catch (Exception e) {
var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId); var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
logger.warn(errorMessage, e); logger.warn(errorMessage, e);
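The change above drops the `HttpClientContext` field from `ExecutableInferenceRequest` and instead creates a fresh context at send time. A sketch of the shape of that change, with simplified stand-in types rather than the actual Elasticsearch/HttpClient classes:

```java
// Sketch: create per-request HTTP state at execution time instead of
// carrying it through the queue. All types here are illustrative stand-ins.
final class Context {
    static Context create() { return new Context(); }
}

interface Sender {
    void send(String request, Context perSendContext);
}

record ExecutableRequest(Sender sender, String request) implements Runnable {
    @Override
    public void run() {
        // A new context per execution means queued tasks, which may run on
        // different threads at different times, never share mutable HTTP state.
        sender.send(request, Context.create());
    }
}

public class ContextSketch {
    public static void main(String[] args) {
        Sender sender = (req, ctx) -> System.out.println("sent " + req);
        new ExecutableRequest(sender, "embeddings").run(); // prints "sent embeddings"
    }
}
```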

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -42,15 +41,14 @@ public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioReques
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -48,17 +47,16 @@ public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioReques
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -15,6 +15,8 @@ import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceComponents;
@ -39,30 +41,28 @@ public class HttpRequestSender implements Sender {
private final ServiceComponents serviceComponents; private final ServiceComponents serviceComponents;
private final HttpClientManager httpClientManager; private final HttpClientManager httpClientManager;
private final ClusterService clusterService; private final ClusterService clusterService;
private final SingleRequestManager requestManager; private final RequestSender requestSender;
public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) { public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
this.serviceComponents = Objects.requireNonNull(serviceComponents); this.serviceComponents = Objects.requireNonNull(serviceComponents);
this.httpClientManager = Objects.requireNonNull(httpClientManager); this.httpClientManager = Objects.requireNonNull(httpClientManager);
this.clusterService = Objects.requireNonNull(clusterService); this.clusterService = Objects.requireNonNull(clusterService);
var requestSender = new RetryingHttpSender( requestSender = new RetryingHttpSender(
this.httpClientManager.getHttpClient(), this.httpClientManager.getHttpClient(),
serviceComponents.throttlerManager(), serviceComponents.throttlerManager(),
new RetrySettings(serviceComponents.settings(), clusterService), new RetrySettings(serviceComponents.settings(), clusterService),
serviceComponents.threadPool() serviceComponents.threadPool()
); );
requestManager = new SingleRequestManager(requestSender);
} }
public Sender createSender(String serviceName) { public Sender createSender() {
return new HttpRequestSender( return new HttpRequestSender(
serviceName,
serviceComponents.threadPool(), serviceComponents.threadPool(),
httpClientManager, httpClientManager,
clusterService, clusterService,
serviceComponents.settings(), serviceComponents.settings(),
requestManager requestSender
); );
} }
} }
@ -71,26 +71,24 @@ public class HttpRequestSender implements Sender {
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final HttpClientManager manager; private final HttpClientManager manager;
private final RequestExecutorService service; private final RequestExecutor service;
private final AtomicBoolean started = new AtomicBoolean(false); private final AtomicBoolean started = new AtomicBoolean(false);
private final CountDownLatch startCompleted = new CountDownLatch(1); private final CountDownLatch startCompleted = new CountDownLatch(1);
private HttpRequestSender( private HttpRequestSender(
String serviceName,
ThreadPool threadPool, ThreadPool threadPool,
HttpClientManager httpClientManager, HttpClientManager httpClientManager,
ClusterService clusterService, ClusterService clusterService,
Settings settings, Settings settings,
SingleRequestManager requestManager RequestSender requestSender
) { ) {
this.threadPool = Objects.requireNonNull(threadPool); this.threadPool = Objects.requireNonNull(threadPool);
this.manager = Objects.requireNonNull(httpClientManager); this.manager = Objects.requireNonNull(httpClientManager);
service = new RequestExecutorService( service = new RequestExecutorService(
serviceName,
threadPool, threadPool,
startCompleted, startCompleted,
new RequestExecutorServiceSettings(settings, clusterService), new RequestExecutorServiceSettings(settings, clusterService),
requestManager requestSender
); );
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -55,26 +54,17 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var truncatedInput = truncate(input, model.getTokenLimit()); var truncatedInput = truncate(input, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest( execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
requestSender,
logger,
request,
context,
responseHandler,
hasRequestCompletedFunction,
listener
);
} }
record RateLimitGrouping(int accountHash) { record RateLimitGrouping(int accountHash) {

View file

@ -19,9 +19,9 @@ import java.util.function.Supplier;
public interface InferenceRequest { public interface InferenceRequest {
/** /**
* Returns the creator that handles building an executable request based on the input provided. * Returns the manager that handles building and executing an inference request.
*/ */
RequestManager getRequestCreator(); RequestManager getRequestManager();
/** /**
* Returns the query associated with this request. Used for Rerank tasks. * Returns the query associated with this request. Used for Rerank tasks.

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -51,18 +50,17 @@ public class MistralEmbeddingsRequestManager extends BaseRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
record RateLimitGrouping(int keyHashCode) { record RateLimitGrouping(int keyHashCode) {

View file

@ -1,52 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import java.util.List;
import java.util.function.Supplier;
class NoopTask implements RejectableTask {
@Override
public RequestManager getRequestCreator() {
return null;
}
@Override
public String getQuery() {
return null;
}
@Override
public List<String> getInput() {
return null;
}
@Override
public ActionListener<InferenceServiceResults> getListener() {
return null;
}
@Override
public boolean hasCompleted() {
return true;
}
@Override
public Supplier<Boolean> getRequestCompletedFunction() {
return () -> true;
}
@Override
public void onRejection(Exception e) {
}
}

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -43,17 +42,16 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
@Nullable String query, @Nullable String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model); OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
private static ResponseHandler createCompletionHandler() { private static ResponseHandler createCompletionHandler() {

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -55,17 +54,16 @@ public class OpenAiEmbeddingsRequestManager extends OpenAiRequestManager {
} }
@Override @Override
public Runnable create( public void execute(
String query, String query,
List<String> input, List<String> input,
RequestSender requestSender, RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction, Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
} }
} }

View file

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.inference.external.http.sender; package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
@ -17,21 +16,31 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings; import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue;
import org.elasticsearch.xpack.inference.common.RateLimiter;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor; import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
/** /**
* A service for queuing and executing {@link RequestTask}. This class is useful because the * A service for queuing and executing {@link RequestTask}. This class is useful because the
@@ -45,7 +54,18 @@ import static org.elasticsearch.core.Strings.format;
* {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info.
*/ */
class RequestExecutorService implements RequestExecutor { class RequestExecutorService implements RequestExecutor {
private static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> QUEUE_CREATOR =
/**
* Provides dependency injection mainly for testing
*/
interface Sleeper {
void sleep(TimeValue sleepTime) throws InterruptedException;
}
// default for tests
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
// default for tests
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
new AdjustableCapacityBlockingQueue.QueueCreator<>() { new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@Override @Override
public BlockingQueue<RejectableTask> create(int capacity) { public BlockingQueue<RejectableTask> create(int capacity) {
@@ -65,86 +85,116 @@ class RequestExecutorService implements RequestExecutor {
} }
}; };
/**
* Provides dependency injection mainly for testing
*/
interface RateLimiterCreator {
RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit);
}
// default for testing
static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new;
private static final Logger logger = LogManager.getLogger(RequestExecutorService.class); private static final Logger logger = LogManager.getLogger(RequestExecutorService.class);
private final String serviceName; private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
private final AtomicBoolean running = new AtomicBoolean(true); private final ConcurrentMap<Object, RateLimitingEndpointHandler> rateLimitGroupings = new ConcurrentHashMap<>();
private final CountDownLatch terminationLatch = new CountDownLatch(1);
private final HttpClientContext httpContext;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final CountDownLatch startupLatch; private final CountDownLatch startupLatch;
private final BlockingQueue<Runnable> controlQueue = new LinkedBlockingQueue<>(); private final CountDownLatch terminationLatch = new CountDownLatch(1);
private final SingleRequestManager requestManager; private final RequestSender requestSender;
private final RequestExecutorServiceSettings settings;
private final Clock clock;
private final AtomicBoolean shutdown = new AtomicBoolean(false);
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
private final Sleeper sleeper;
private final RateLimiterCreator rateLimiterCreator;
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
private final AtomicBoolean started = new AtomicBoolean(false);
RequestExecutorService( RequestExecutorService(
String serviceName,
ThreadPool threadPool, ThreadPool threadPool,
@Nullable CountDownLatch startupLatch, @Nullable CountDownLatch startupLatch,
RequestExecutorServiceSettings settings, RequestExecutorServiceSettings settings,
SingleRequestManager requestManager RequestSender requestSender
) { ) {
this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager); this(
threadPool,
DEFAULT_QUEUE_CREATOR,
startupLatch,
settings,
requestSender,
Clock.systemUTC(),
DEFAULT_SLEEPER,
DEFAULT_RATE_LIMIT_CREATOR
);
} }
/**
* This constructor should only be used directly for testing.
*/
RequestExecutorService( RequestExecutorService(
String serviceName,
ThreadPool threadPool, ThreadPool threadPool,
AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue, AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator,
@Nullable CountDownLatch startupLatch, @Nullable CountDownLatch startupLatch,
RequestExecutorServiceSettings settings, RequestExecutorServiceSettings settings,
SingleRequestManager requestManager RequestSender requestSender,
Clock clock,
Sleeper sleeper,
RateLimiterCreator rateLimiterCreator
) { ) {
this.serviceName = Objects.requireNonNull(serviceName);
this.threadPool = Objects.requireNonNull(threadPool); this.threadPool = Objects.requireNonNull(threadPool);
this.httpContext = HttpClientContext.create(); this.queueCreator = Objects.requireNonNull(queueCreator);
this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
this.startupLatch = startupLatch; this.startupLatch = startupLatch;
this.requestManager = Objects.requireNonNull(requestManager); this.requestSender = Objects.requireNonNull(requestSender);
this.settings = Objects.requireNonNull(settings);
Objects.requireNonNull(settings); this.clock = Objects.requireNonNull(clock);
settings.registerQueueCapacityCallback(this::onCapacityChange); this.sleeper = Objects.requireNonNull(sleeper);
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
} }
private void onCapacityChange(int capacity) { public void shutdown() {
logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity)); if (shutdown.compareAndSet(false, true)) {
if (cancellableCleanupTask.get() != null) {
var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity)); logger.debug(() -> "Stopping clean up thread");
if (enqueuedCapacityCommand == false) { cancellableCleanupTask.get().cancel();
logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later."); }
} else {
// ensure that the task execution loop wakes up
queue.offer(new NoopTask());
} }
} }
private void updateCapacity(int newCapacity) { public boolean isShutdown() {
try { return shutdown.get();
queue.setCapacity(newCapacity); }
} catch (Exception e) {
logger.warn( public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName), return terminationLatch.await(timeout, unit);
e }
);
} public boolean isTerminated() {
return terminationLatch.getCount() == 0;
}
public int queueSize() {
return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum();
} }
/** /**
* Begin servicing tasks. * Begin servicing tasks.
* <p>
* <b>Note: This should only be called once for the life of the object.</b>
* </p>
*/ */
public void start() { public void start() {
try { try {
assert started.get() == false : "start() can only be called once";
started.set(true);
startCleanupTask();
signalStartInitiated(); signalStartInitiated();
while (running.get()) { while (isShutdown() == false) {
handleTasks(); handleTasks();
} }
} catch (InterruptedException e) { } catch (InterruptedException e) {
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
} finally { } finally {
running.set(false); shutdown();
notifyRequestsOfShutdown(); notifyRequestsOfShutdown();
terminationLatch.countDown(); terminationLatch.countDown();
} }
@@ -156,108 +206,68 @@ class RequestExecutorService implements RequestExecutor {
} }
} }
/** private void startCleanupTask() {
* Protects the task retrieval logic from an unexpected exception. assert cancellableCleanupTask.get() == null : "The clean up task can only be set once";
* cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL));
* @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to }
* shut down
*/ private Scheduler.Cancellable startCleanupThread(TimeValue interval) {
logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval));
return threadPool.scheduleWithFixedDelay(this::removeStaleGroupings, interval, threadPool.executor(UTILITY_THREAD_POOL_NAME));
}
// default for testing
void removeStaleGroupings() {
var now = Instant.now(clock);
for (var iter = rateLimitGroupings.values().iterator(); iter.hasNext();) {
var endpoint = iter.next();
// if the current time is after the last time the endpoint enqueued a request + allowed stale period then we'll remove it
if (now.isAfter(endpoint.timeOfLastEnqueue().plus(settings.getRateLimitGroupStaleDuration()))) {
endpoint.close();
iter.remove();
}
}
}
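The `removeStaleGroupings` sweep above drops any rate limit grouping whose last enqueue is older than the configured stale duration. The core of that sweep can be sketched with a plain map standing in for the grouping-to-handler map (the real code closes each `RateLimitingEndpointHandler` before removing it):

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Iterator;
import java.util.Map;

// Illustrative sketch of the stale-grouping sweep: remove any entry whose
// last-enqueue time plus the stale duration is in the past. Map values here
// stand in for RateLimitingEndpointHandler instances.
class StaleSweepSketch {
    static void removeStale(Map<String, Instant> lastEnqueueById, Clock clock, Duration staleDuration) {
        Instant now = Instant.now(clock);
        for (Iterator<Map.Entry<String, Instant>> iter = lastEnqueueById.entrySet().iterator(); iter.hasNext();) {
            Map.Entry<String, Instant> entry = iter.next();
            if (now.isAfter(entry.getValue().plus(staleDuration))) {
                iter.remove(); // the real code also closes the handler here
            }
        }
    }
}
```

Taking the `Clock` as a parameter (as the production class does) is what makes this sweep testable with a fixed clock instead of real wall time.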
private void handleTasks() throws InterruptedException { private void handleTasks() throws InterruptedException {
try { var timeToWait = settings.getTaskPollFrequency();
RejectableTask task = queue.take(); for (var endpoint : rateLimitGroupings.values()) {
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
var command = controlQueue.poll();
if (command != null) {
command.run();
}
// TODO add logic to complete pending items in the queue before shutting down
if (running.get() == false) {
logger.debug(() -> format("Http executor service [%s] exiting", serviceName));
rejectTaskBecauseOfShutdown(task);
} else {
executeTask(task);
}
} catch (InterruptedException e) {
throw e;
} catch (Exception e) {
logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e);
} }
sleeper.sleep(timeToWait);
} }
private void executeTask(RejectableTask task) { private void notifyRequestsOfShutdown() {
try {
requestManager.execute(task, httpContext);
} catch (Exception e) {
logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e);
}
}
private synchronized void notifyRequestsOfShutdown() {
assert isShutdown() : "Requests should only be notified if the executor is shutting down"; assert isShutdown() : "Requests should only be notified if the executor is shutting down";
try { for (var endpoint : rateLimitGroupings.values()) {
List<RejectableTask> notExecuted = new ArrayList<>(); endpoint.notifyRequestsOfShutdown();
queue.drainTo(notExecuted);
rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown);
} catch (Exception e) {
logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName));
} }
} }
private void rejectTaskBecauseOfShutdown(RejectableTask task) { // default for testing
try { Integer remainingQueueCapacity(RequestManager requestManager) {
task.onRejection( var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping());
new EsRejectedExecutionException(
format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName), if (endpoint == null) {
true return null;
)
);
} catch (Exception e) {
logger.warn(
format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName)
);
} }
return endpoint.remainingCapacity();
} }
private void rejectTasks(List<RejectableTask> tasks, Consumer<RejectableTask> rejectionFunction) { // default for testing
for (var task : tasks) { int numberOfRateLimitGroups() {
rejectionFunction.accept(task); return rateLimitGroupings.size();
}
}
public int queueSize() {
return queue.size();
}
@Override
public void shutdown() {
if (running.compareAndSet(true, false)) {
// if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns
queue.offer(new NoopTask());
}
}
@Override
public boolean isShutdown() {
return running.get() == false;
}
@Override
public boolean isTerminated() {
return terminationLatch.getCount() == 0;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return terminationLatch.await(timeout, unit);
} }
/** /**
* Execute the request at some point in the future. * Execute the request at some point in the future.
* *
* @param requestCreator the http request to send * @param requestManager the http request to send
* @param inferenceInputs the inputs to send in the request * @param inferenceInputs the inputs to send in the request
* @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the
* listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}. * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}.
@@ -265,13 +275,13 @@ class RequestExecutorService implements RequestExecutor {
* @param listener an {@link ActionListener<InferenceServiceResults>} for the response or failure * @param listener an {@link ActionListener<InferenceServiceResults>} for the response or failure
*/ */
public void execute( public void execute(
RequestManager requestCreator, RequestManager requestManager,
InferenceInputs inferenceInputs, InferenceInputs inferenceInputs,
@Nullable TimeValue timeout, @Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener ActionListener<InferenceServiceResults> listener
) { ) {
var task = new RequestTask( var task = new RequestTask(
requestCreator, requestManager,
inferenceInputs, inferenceInputs,
timeout, timeout,
threadPool, threadPool,
@@ -280,38 +290,230 @@ class RequestExecutorService implements RequestExecutor {
ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext())
); );
completeExecution(task); var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> {
} var endpointHandler = new RateLimitingEndpointHandler(
Integer.toString(requestManager.rateLimitGrouping().hashCode()),
private void completeExecution(RequestTask task) { queueCreator,
if (isShutdown()) { settings,
EsRejectedExecutionException rejected = new EsRejectedExecutionException( requestSender,
format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName), clock,
true requestManager.rateLimitSettings(),
this::isShutdown,
rateLimiterCreator
); );
task.onRejection(rejected); endpointHandler.init();
return; return endpointHandler;
} });
boolean added = queue.offer(task); endpoint.enqueue(task);
if (added == false) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
format("Failed to execute task because the http executor service [%s] queue is full", serviceName),
false
);
task.onRejection(rejected);
} else if (isShutdown()) {
// It is possible that a shutdown and notification request occurred after we initially checked for shutdown above
// If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if
// we shut down and if so we'll redo the notification
notifyRequestsOfShutdown();
}
} }
// default for testing /**
int remainingQueueCapacity() { * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests.
return queue.remainingCapacity(); * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent
* as soon as a thread is available.
*/
private static class RateLimitingEndpointHandler {
private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE;
private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO;
private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class);
private static final int ACCUMULATED_TOKENS_LIMIT = 1;
private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
private final Supplier<Boolean> isShutdownMethod;
private final RequestSender requestSender;
private final String id;
private final AtomicReference<Instant> timeOfLastEnqueue = new AtomicReference<>();
private final Clock clock;
private final RateLimiter rateLimiter;
private final RequestExecutorServiceSettings requestExecutorServiceSettings;
RateLimitingEndpointHandler(
String id,
AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
RequestExecutorServiceSettings settings,
RequestSender requestSender,
Clock clock,
RateLimitSettings rateLimitSettings,
Supplier<Boolean> isShutdownMethod,
RateLimiterCreator rateLimiterCreator
) {
this.requestExecutorServiceSettings = Objects.requireNonNull(settings);
this.id = Objects.requireNonNull(id);
this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
this.requestSender = Objects.requireNonNull(requestSender);
this.clock = Objects.requireNonNull(clock);
this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod);
Objects.requireNonNull(rateLimitSettings);
Objects.requireNonNull(rateLimiterCreator);
rateLimiter = rateLimiterCreator.create(
ACCUMULATED_TOKENS_LIMIT,
rateLimitSettings.requestsPerTimeUnit(),
rateLimitSettings.timeUnit()
);
}
public void init() {
requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange);
}
private void onCapacityChange(int capacity) {
logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity));
try {
queue.setCapacity(capacity);
} catch (Exception e) {
logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e);
}
}
public int queueSize() {
return queue.size();
}
public boolean isShutdown() {
return isShutdownMethod.get();
}
public Instant timeOfLastEnqueue() {
return timeOfLastEnqueue.get();
}
public synchronized TimeValue executeEnqueuedTask() {
try {
return executeEnqueuedTaskInternal();
} catch (Exception e) {
logger.warn(format("Executor service grouping [%s] failed to execute request", id), e);
// we tried to do some work but failed, so we'll say we did something to try looking for more work
return EXECUTED_A_TASK;
}
}
private TimeValue executeEnqueuedTaskInternal() {
var timeBeforeAvailableToken = rateLimiter.timeToReserve(1);
if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) {
return timeBeforeAvailableToken;
}
var task = queue.poll();
// TODO Batching - in a situation where no new tasks are queued we'll want to execute any prepared tasks
// So we'll need to check for null and call a helper method executePreparedTasks()
if (shouldExecuteTask(task) == false) {
return NO_TASKS_AVAILABLE;
}
// We should never have to wait because we checked above
var reserveRes = rateLimiter.reserve(1);
assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have";
task.getRequestManager()
.execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener());
return EXECUTED_A_TASK;
}
private static boolean shouldExecuteTask(RejectableTask task) {
return task != null && isNoopRequest(task) == false && task.hasCompleted() == false;
}
private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
return inferenceRequest.getRequestManager() == null
|| inferenceRequest.getInput() == null
|| inferenceRequest.getListener() == null;
}
private static boolean shouldExecuteImmediately(TimeValue delay) {
return delay.duration() == 0;
}
public void enqueue(RequestTask task) {
timeOfLastEnqueue.set(Instant.now(clock));
if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
format(
"Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown",
task.getRequestManager().inferenceEntityId(),
id
),
true
);
task.onRejection(rejected);
return;
}
var addedToQueue = queue.offer(task);
if (addedToQueue == false) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
format(
"Failed to execute task for inference id [%s] because the request service [%s] queue is full",
task.getRequestManager().inferenceEntityId(),
id
),
false
);
task.onRejection(rejected);
} else if (isShutdown()) {
notifyRequestsOfShutdown();
}
}
public synchronized void notifyRequestsOfShutdown() {
assert isShutdown() : "Requests should only be notified if the executor is shutting down";
try {
List<RejectableTask> notExecuted = new ArrayList<>();
queue.drainTo(notExecuted);
rejectTasks(notExecuted);
} catch (Exception e) {
logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id));
}
}
private void rejectTasks(List<RejectableTask> tasks) {
for (var task : tasks) {
rejectTaskForShutdown(task);
}
}
private void rejectTaskForShutdown(RejectableTask task) {
try {
task.onRejection(
new EsRejectedExecutionException(
format(
"Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request",
id,
task.getRequestManager().inferenceEntityId()
),
true
)
);
} catch (Exception e) {
logger.warn(
format(
"Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown",
task.getRequestManager().inferenceEntityId(),
id
)
);
}
}
public int remainingCapacity() {
return queue.remainingCapacity();
}
public void close() {
requestExecutorServiceSettings.deregisterQueueCapacityCallback(id);
}
} }
} }
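The `RateLimiter` that `RateLimitingEndpointHandler` creates (with an accumulated-tokens limit of 1) follows token-bucket semantics: `timeToReserve` reports how long until tokens would be available without spending them, and `reserve` actually spends them. A minimal sketch with the same constructor shape — this is illustrative only, not the production `RateLimiter` implementation:

```java
import java.util.concurrent.TimeUnit;

// Token-bucket sketch mirroring the RateLimiter shape used above:
// timeToReserveNanos() reports the wait without spending tokens,
// reserveNanos() spends them (possibly going negative, delaying later callers).
class TokenBucketSketch {
    private final double maxAccumulated;   // cap on stored (burst) tokens
    private final double tokensPerNano;    // refill rate
    private double tokens;
    private long lastRefillNanos;

    TokenBucketSketch(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit) {
        this.maxAccumulated = accumulatedTokensLimit;
        this.tokensPerNano = tokensPerTimeUnit / unit.toNanos(1);
        this.tokens = accumulatedTokensLimit;
        this.lastRefillNanos = System.nanoTime();
    }

    private void refill(long nowNanos) {
        tokens = Math.min(maxAccumulated, tokens + (nowNanos - lastRefillNanos) * tokensPerNano);
        lastRefillNanos = nowNanos;
    }

    // Nanoseconds until n tokens could be reserved; 0 means "go now".
    synchronized long timeToReserveNanos(int n) {
        refill(System.nanoTime());
        double deficit = n - tokens;
        return deficit <= 0 ? 0L : (long) Math.ceil(deficit / tokensPerNano);
    }

    // Spend n tokens, returning the wait the caller should sleep (may be 0).
    synchronized long reserveNanos(int n) {
        long wait = timeToReserveNanos(n);
        tokens -= n;
        return wait;
    }
}
```

This explains the assertion in `executeEnqueuedTaskInternal`: the handler first checks `timeToReserve(1)`, and only calls `reserve(1)` once it knows the reservation requires no sleep.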


@@ -10,9 +10,12 @@ package org.elasticsearch.xpack.inference.external.http.sender;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
 
-import java.util.ArrayList;
+import java.time.Duration;
 import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.function.Consumer;
 
 public class RequestExecutorServiceSettings {
@@ -29,37 +32,108 @@ public class RequestExecutorServiceSettings {
         Setting.Property.Dynamic
     );
 
+    private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50);
+    /**
+     * Defines how often all the rate limit groups are polled for tasks. Setting this to very low number could result
+     * in a busy loop if there are no tasks available to handle.
+     */
+    static final Setting<TimeValue> TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting(
+        "xpack.inference.http.request_executor.task_poll_frequency",
+        DEFAULT_TASK_POLL_FREQUENCY_TIME,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
+    /**
+     * Defines how often a thread will check for rate limit groups that are stale.
+     */
+    static final Setting<TimeValue> RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING = Setting.timeSetting(
+        "xpack.inference.http.request_executor.rate_limit_group_cleanup_interval",
+        DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION = TimeValue.timeValueDays(10);
+    /**
+     * Defines the amount of time it takes to classify a rate limit group as stale. Once it is classified as stale,
+     * it can be removed when the cleanup thread executes.
+     */
+    static final Setting<TimeValue> RATE_LIMIT_GROUP_STALE_DURATION_SETTING = Setting.timeSetting(
+        "xpack.inference.http.request_executor.rate_limit_group_stale_duration",
+        DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
     public static List<Setting<?>> getSettingsDefinitions() {
-        return List.of(TASK_QUEUE_CAPACITY_SETTING);
+        return List.of(
+            TASK_QUEUE_CAPACITY_SETTING,
+            TASK_POLL_FREQUENCY_SETTING,
+            RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING,
+            RATE_LIMIT_GROUP_STALE_DURATION_SETTING
+        );
     }
 
     private volatile int queueCapacity;
-    private final List<Consumer<Integer>> queueCapacityCallbacks = new ArrayList<Consumer<Integer>>();
+    private volatile TimeValue taskPollFrequency;
+    private volatile Duration rateLimitGroupStaleDuration;
+    private final ConcurrentMap<String, Consumer<Integer>> queueCapacityCallbacks = new ConcurrentHashMap<>();
 
     public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) {
         queueCapacity = TASK_QUEUE_CAPACITY_SETTING.get(settings);
+        taskPollFrequency = TASK_POLL_FREQUENCY_SETTING.get(settings);
+        setRateLimitGroupStaleDuration(RATE_LIMIT_GROUP_STALE_DURATION_SETTING.get(settings));
 
         addSettingsUpdateConsumers(clusterService);
     }
 
     private void addSettingsUpdateConsumers(ClusterService clusterService) {
         clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency);
+        clusterService.getClusterSettings()
+            .addSettingsUpdateConsumer(RATE_LIMIT_GROUP_STALE_DURATION_SETTING, this::setRateLimitGroupStaleDuration);
     }
 
     // default for testing
     void setQueueCapacity(int queueCapacity) {
         this.queueCapacity = queueCapacity;
 
-        for (var callback : queueCapacityCallbacks) {
+        for (var callback : queueCapacityCallbacks.values()) {
             callback.accept(queueCapacity);
         }
     }
 
-    void registerQueueCapacityCallback(Consumer<Integer> onChangeCapacityCallback) {
-        queueCapacityCallbacks.add(onChangeCapacityCallback);
+    private void setTaskPollFrequency(TimeValue taskPollFrequency) {
+        this.taskPollFrequency = taskPollFrequency;
+    }
+
+    private void setRateLimitGroupStaleDuration(TimeValue staleDuration) {
+        rateLimitGroupStaleDuration = toDuration(staleDuration);
+    }
+
+    private static Duration toDuration(TimeValue timeValue) {
+        return Duration.of(timeValue.duration(), timeValue.timeUnit().toChronoUnit());
+    }
+
+    void registerQueueCapacityCallback(String id, Consumer<Integer> onChangeCapacityCallback) {
+        queueCapacityCallbacks.put(id, onChangeCapacityCallback);
+    }
+
+    void deregisterQueueCapacityCallback(String id) {
+        queueCapacityCallbacks.remove(id);
     }
 
     int getQueueCapacity() {
         return queueCapacity;
     }
+
+    TimeValue getTaskPollFrequency() {
+        return taskPollFrequency;
+    }
+
+    Duration getRateLimitGroupStaleDuration() {
+        return rateLimitGroupStaleDuration;
+    }
 }
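The `task_poll_frequency` setting defined here caps the executor's sleep in `handleTasks`: each iteration polls every rate limit grouping, and the loop sleeps for the smallest wait any grouping reports (0 means a task was executed, a maximal value means the queue was empty). That min-wait selection can be sketched as follows, with `LongSupplier` standing in for `RateLimitingEndpointHandler.executeEnqueuedTask()`:

```java
import java.util.List;
import java.util.function.LongSupplier;

// Sketch of the single-thread polling loop's sleep calculation: each grouping
// reports how long until it can run again, and the loop sleeps for the
// smallest of those waits, capped at the configured poll frequency.
class PollingLoopSketch {
    static long nextSleepMillis(List<LongSupplier> groupings, long pollFrequencyMillis) {
        long timeToWait = pollFrequencyMillis;
        for (LongSupplier grouping : groupings) {
            // 0 means "executed a task"; Long.MAX_VALUE means "no tasks available"
            timeToWait = Math.min(grouping.getAsLong(), timeToWait);
        }
        return timeToWait;
    }
}
```

Capping at the poll frequency is what prevents both a busy loop (when every queue is empty) and oversleeping (a newly enqueued task waits at most one poll interval).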


@@ -7,7 +7,6 @@
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -21,14 +20,17 @@ import java.util.function.Supplier;
  * A contract for constructing a {@link Runnable} to handle sending an inference request to a 3rd party service.
  */
 public interface RequestManager extends RateLimitable {
-    Runnable create(
+    void execute(
         @Nullable String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     );
 
+    // TODO For batching we'll add 2 new method: prepare(query, input, ...) which will allow the individual
+    // managers to implement their own batching
+    // executePreparedRequest() which will execute all prepared requests aka sends the batch
+
     String inferenceEntityId();
 }

View file

@@ -111,7 +111,7 @@ class RequestTask implements RejectableTask {
     }
 
     @Override
-    public RequestManager getRequestCreator() {
+    public RequestManager getRequestManager() {
         return requestCreator;
     }
 }


@@ -1,48 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.sender;
-
-import org.apache.http.client.protocol.HttpClientContext;
-import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
-
-import java.util.Objects;
-
-/**
- * Handles executing a single inference request at a time.
- */
-public class SingleRequestManager {
-
-    protected RetryingHttpSender requestSender;
-
-    public SingleRequestManager(RetryingHttpSender requestSender) {
-        this.requestSender = Objects.requireNonNull(requestSender);
-    }
-
-    public void execute(InferenceRequest inferenceRequest, HttpClientContext context) {
-        if (isNoopRequest(inferenceRequest) || inferenceRequest.hasCompleted()) {
-            return;
-        }
-
-        inferenceRequest.getRequestCreator()
-            .create(
-                inferenceRequest.getQuery(),
-                inferenceRequest.getInput(),
-                requestSender,
-                inferenceRequest.getRequestCompletedFunction(),
-                context,
-                inferenceRequest.getListener()
-            )
-            .run();
-    }
-
-    private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
-        return inferenceRequest.getRequestCreator() == null
-            || inferenceRequest.getInput() == null
-            || inferenceRequest.getListener() == null;
-    }
-}


@@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService {
 
     public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         Objects.requireNonNull(factory);
-        sender = factory.createSender(name());
+        sender = factory.createSender();
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
     }

View file

@@ -56,7 +56,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completio
 public class AzureAiStudioService extends SenderService {

-    private static final String NAME = "azureaistudio";
+    static final String NAME = "azureaistudio";

     public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);

View file

@@ -44,7 +44,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
         ConfigurationParseContext context
     ) {
         String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureAiStudioService.NAME,
+            context
+        );
         AzureAiStudioEndpointType endpointType = extractRequiredEnum(
             map,
             ENDPOINT_TYPE_FIELD,
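Every call site like the one above now hands `RateLimitSettings.of` the service name and a `ConfigurationParseContext` in addition to the map, default, and `ValidationException`, so validation errors can name the offending service and persisted configs can be treated differently from new requests. A rough sketch of what such a parser does with those inputs (hypothetical `RateLimitSettingsSketch`, heavily simplified from the real class):

```java
import java.util.Map;

/** Hypothetical sketch of pulling a rate_limit block out of a settings map. */
final class RateLimitSettingsSketch {
    static final String RATE_LIMIT_FIELD = "rate_limit";
    static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute";

    static long of(Map<String, Object> settings, long defaultRequestsPerMinute, String serviceName) {
        // Consumed keys are removed so leftover keys can be flagged later.
        Object rateLimit = settings.remove(RATE_LIMIT_FIELD);
        if (rateLimit == null) {
            return defaultRequestsPerMinute;
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> rateLimitMap = (Map<String, Object>) rateLimit;
        Number rpm = (Number) rateLimitMap.get(REQUESTS_PER_MINUTE_FIELD);
        if (rpm == null || rpm.longValue() <= 0) {
            // The real code accumulates this in a ValidationException instead.
            throw new IllegalArgumentException(
                "[" + serviceName + "] requires a positive [" + REQUESTS_PER_MINUTE_FIELD + "]"
            );
        }
        return rpm.longValue();
    }
}
```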
@@ -118,13 +124,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
     protected void addXContentFields(XContentBuilder builder, Params params) throws IOException {
         this.addExposedXContentFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
     }

     protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(TARGET_FIELD, this.target);
         builder.field(PROVIDER_FIELD, this.provider);
         builder.field(ENDPOINT_TYPE_FIELD, this.endpointType);
+        rateLimitSettings.toXContent(builder, params);
     }
 }
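Because `rateLimitSettings.toXContent` now runs inside `addExposedXContentFields`, the limit is serialized alongside the other user-visible service settings, and callers can override the per-service default when creating an endpoint. An illustrative request (endpoint name and field values are examples only):

```json
PUT _inference/completion/azure_ai_studio_completion
{
  "service": "azureaistudio",
  "service_settings": {
    "api_key": "<api_key>",
    "target": "<target_uri>",
    "provider": "databricks",
    "endpoint_type": "realtime",
    "rate_limit": {
      "requests_per_minute": 120
    }
  }
}
```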

View file

@@ -135,7 +135,15 @@ public class AzureOpenAiService extends SenderService {
                 );
             }
             case COMPLETION -> {
-                return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+                return new AzureOpenAiCompletionModel(
+                    inferenceEntityId,
+                    taskType,
+                    NAME,
+                    serviceSettings,
+                    taskSettings,
+                    secretSettings,
+                    context
+                );
             }
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         }

View file

@@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
@@ -37,13 +38,14 @@ public class AzureOpenAiCompletionModel extends AzureOpenAiModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings),
+            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context),
             AzureOpenAiCompletionTaskSettings.fromMap(taskSettings),
             AzureOpenAiSecretSettings.fromMap(secrets)
         );

View file

@@ -17,7 +17,9 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -55,10 +57,10 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);

-    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

-        var settings = fromMap(map, validationException);
+        var settings = fromMap(map, validationException, context);

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -69,12 +71,19 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
     private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap(
         Map<String, Object> map,
-        ValidationException validationException
+        ValidationException validationException,
+        ConfigurationParseContext context
     ) {
         String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureOpenAiService.NAME,
+            context
+        );

         return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings);
     }
@@ -137,7 +146,6 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -148,6 +156,7 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
         builder.field(RESOURCE_NAME, resourceName);
         builder.field(DEPLOYMENT_ID, deploymentId);
         builder.field(API_VERSION, apiVersion);
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

View file

@@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -90,7 +91,13 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
         SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureOpenAiService.NAME,
+            context
+        );
         Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);
@@ -245,8 +252,6 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);

         builder.endObject();
@@ -268,6 +273,7 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         if (similarity != null) {
             builder.field(SIMILARITY, similarity);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

View file

@@ -51,6 +51,11 @@ import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFie
 public class CohereService extends SenderService {
     public static final String NAME = "cohere";

+    // TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to
+    // the Cohere*RequestManager via the CohereActionCreator class
+    // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
+    // on every request
+
     public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
     }
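The TODO above notes that batching state has to live at the service level because a fresh `*RequestManager` is built for every request. A minimal shape for that service-scoped state might look like this (hypothetical `BatchAccumulator`, not part of this change):

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical service-scoped accumulator: it lives as long as the service,
 * so the per-request manager objects can share one batch of pending inputs.
 */
final class BatchAccumulator {
    private final int maxBatchSize;
    private final List<String> pending = new ArrayList<>();

    BatchAccumulator(int maxBatchSize) {
        this.maxBatchSize = maxBatchSize;
    }

    /** Adds inputs and returns a full batch to flush, or an empty list while still accumulating. */
    synchronized List<String> add(List<String> inputs) {
        pending.addAll(inputs);
        if (pending.size() < maxBatchSize) {
            return List.of();
        }
        List<String> batch = new ArrayList<>(pending);
        pending.clear();
        return batch;
    }
}
```

The service would construct one accumulator and pass it into each action creator, which is exactly the plumbing the TODO sketches for `CohereActionCreator`.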
@@ -131,7 +136,15 @@ public class CohereService extends SenderService {
                 context
             );
             case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
-            case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+            case COMPLETION -> new CohereCompletionModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }

View file

@@ -58,7 +58,13 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            CohereService.NAME,
+            context
+        );
         String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -173,10 +179,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
     }

     public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
-        toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
-
-        return builder;
+        return toXContentFragmentOfExposedFields(builder, params);
     }

     @Override
@@ -196,6 +199,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
         if (modelId != null) {
             builder.field(MODEL_ID, modelId);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

View file

@@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -30,13 +31,14 @@ public class CohereCompletionModel extends CohereModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             modelId,
             taskType,
             service,
-            CohereCompletionServiceSettings.fromMap(serviceSettings),
+            CohereCompletionServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );

View file

@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.cohere.CohereService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -39,12 +41,18 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
     // 10K requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);

-    public static CohereCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static CohereCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            CohereService.NAME,
+            context
+        );
         String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

         if (validationException.validationErrors().isEmpty() == false) {
@@ -94,7 +102,6 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -127,6 +134,7 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
         if (modelId != null) {
             builder.field(MODEL_ID, modelId);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

View file

@@ -108,7 +108,8 @@ public class GoogleAiStudioService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel(
                 inferenceEntityId,
@@ -116,7 +117,8 @@ public class GoogleAiStudioService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };

View file

@@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -37,13 +38,14 @@ public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        Map<String, Object> secrets
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings),
+            GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );

View file

@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -40,11 +42,17 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);

-    public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            GoogleAiStudioService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -82,7 +90,6 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -107,6 +114,7 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
     @Override
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(MODEL_ID, modelId);
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

View file

@@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -37,13 +38,14 @@ public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        Map<String, Object> secrets
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings),
+            GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );

View file

@@ -18,7 +18,9 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -47,7 +49,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);

-    public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map) {
+    public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -59,7 +61,13 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         );
         SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
         Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            GoogleAiStudioService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -134,7 +142,6 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -174,6 +181,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         if (similarity != null) {
             builder.field(SIMILARITY, similarity);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }


@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -62,7 +63,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
             taskType,
             serviceSettingsMap,
             serviceSettingsMap,
-            TaskType.unsupportedTaskTypeErrorMsg(taskType, name())
+            TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
+            ConfigurationParseContext.REQUEST
         );
         throwIfNotEmptyMap(config, name());
@@ -89,7 +91,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
             taskType,
             serviceSettingsMap,
             secretSettingsMap,
-            parsePersistedConfigErrorMsg(inferenceEntityId, name())
+            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
+            ConfigurationParseContext.PERSISTENT
         );
     }
@@ -97,7 +100,14 @@ public abstract class HuggingFaceBaseService extends SenderService {
     public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
-        return createModel(inferenceEntityId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(inferenceEntityId, name()));
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            null,
+            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
+            ConfigurationParseContext.PERSISTENT
+        );
     }
     protected abstract HuggingFaceModel createModel(
@@ -105,7 +115,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     );
     @Override


@@ -16,6 +16,7 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
@@ -36,11 +37,19 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return switch (taskType) {
-            case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
-            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
+            case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                secretSettings,
+                context
+            );
+            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }


@@ -18,6 +18,7 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -43,14 +44,20 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
     // 3000 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
-    public static HuggingFaceServiceSettings fromMap(Map<String, Object> map) {
+    public static HuggingFaceServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var uri = extractUri(map, URL, validationException);
         SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            HuggingFaceService.NAME,
+            context
+        );
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -119,7 +126,6 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
     }
@@ -136,6 +142,7 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);
         return builder;
     }
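The serialization change in the hunks above follows one pattern across all the touched settings classes: the `rate_limit` output moves out of the outer `toXContent` body and into `toXContentFragmentOfExposedFields`, so it is written once by the fragment and appears in both the full persisted form and the filtered, user-facing form. The following is a minimal, hypothetical sketch of that pattern (the class name, fields, and string-building are illustrative stand-ins, not the real XContent API):

```java
// Simplified sketch: the full serializer delegates to a fragment of
// "exposed" fields; anything emitted from the fragment -- as rate_limit
// now is -- also shows up in the filtered output, without duplication.
class ExposedFieldsSketch {
    private final String url = "https://api.example.com"; // illustrative value
    private final int requestsPerMinute = 3000;

    // Full form: just wraps the exposed-fields fragment in an object.
    String toXContent() {
        StringBuilder b = new StringBuilder("{");
        toFragmentOfExposedFields(b); // rate_limit is emitted here now
        return b.append("}").toString();
    }

    // Fields written here are also part of the filtered, user-visible output.
    private void toFragmentOfExposedFields(StringBuilder b) {
        b.append("\"url\":\"").append(url).append("\",");
        b.append("\"rate_limit\":{\"requests_per_minute\":").append(requestsPerMinute).append("}");
    }

    public static void main(String[] args) {
        System.out.println(new ExposedFieldsSketch().toXContent());
    }
}
```

Keeping the field in exactly one of the two methods is what prevents `rate_limit` from being serialized twice when the fragment is embedded in the full document.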


@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -24,13 +25,14 @@ public class HuggingFaceElserModel extends HuggingFaceModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            HuggingFaceElserServiceSettings.fromMap(serviceSettings),
+            HuggingFaceElserServiceSettings.fromMap(serviceSettings, context),
             DefaultSecretSettings.fromMap(secrets)
         );
     }


@@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
@@ -38,10 +39,11 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return switch (taskType) {
-            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
+            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }


@@ -15,7 +15,9 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -40,10 +42,16 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     // 3000 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
-    public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map) {
+    public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var uri = extractUri(map, URL, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            HuggingFaceService.NAME,
+            context
+        );
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -93,7 +101,6 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
@@ -103,6 +110,7 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(URL, uri.toString());
         builder.field(MAX_INPUT_TOKENS, ELSER_TOKEN_LIMIT);
+        rateLimitSettings.toXContent(builder, params);
         return builder;
     }


@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -25,13 +26,14 @@ public class HuggingFaceEmbeddingsModel extends HuggingFaceModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            HuggingFaceServiceSettings.fromMap(serviceSettings),
+            HuggingFaceServiceSettings.fromMap(serviceSettings, context),
             DefaultSecretSettings.fromMap(secrets)
         );
     }


@@ -18,6 +18,7 @@ import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -59,7 +60,13 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
             ModelConfigurations.SERVICE_SETTINGS,
             validationException
         );
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            MistralService.NAME,
+            context
+        );
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         if (validationException.validationErrors().isEmpty() == false) {
@@ -141,7 +148,6 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         this.toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
     }
@@ -159,6 +165,7 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
         if (this.maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, this.maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);
         return builder;
     }


@@ -138,7 +138,8 @@ public class OpenAiService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };


@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -35,13 +36,14 @@ public class OpenAiChatCompletionModel extends OpenAiModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            OpenAiChatCompletionServiceSettings.fromMap(serviceSettings),
+            OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
             OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
             DefaultSecretSettings.fromMap(secrets)
         );


@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -47,7 +49,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
     // 500 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
-    public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -58,7 +60,13 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            OpenAiService.NAME,
+            context
+        );
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -142,7 +150,6 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
@@ -163,6 +170,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);
         return builder;
     }


@@ -20,6 +20,7 @@ import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -66,7 +67,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         // passed at that time and never throw.
         ValidationException validationException = new ValidationException();
-        var commonFields = fromMap(map, validationException);
+        var commonFields = fromMap(map, validationException, ConfigurationParseContext.PERSISTENT);
         Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
         if (dimensionsSetByUser == null) {
@@ -80,7 +81,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
     private static OpenAiEmbeddingsServiceSettings fromRequestMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
-        var commonFields = fromMap(map, validationException);
+        var commonFields = fromMap(map, validationException, ConfigurationParseContext.REQUEST);
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -89,7 +90,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         return new OpenAiEmbeddingsServiceSettings(commonFields, commonFields.dimensions != null);
     }
-    private static CommonFields fromMap(Map<String, Object> map, ValidationException validationException) {
+    private static CommonFields fromMap(
+        Map<String, Object> map,
+        ValidationException validationException,
+        ConfigurationParseContext context
+    ) {
         String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -98,7 +103,13 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            OpenAiService.NAME,
+            context
+        );
         return new CommonFields(modelId, uri, organizationId, similarity, maxInputTokens, dims, rateLimitSettings);
     }
@@ -258,7 +269,6 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         if (dimensionsSetByUser != null) {
             builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
@@ -286,6 +296,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);
         return builder;
     }


@@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import java.io.IOException;
 import java.util.Map;
@@ -21,19 +22,29 @@ import java.util.concurrent.TimeUnit;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
 public class RateLimitSettings implements Writeable, ToXContentFragment {
     public static final String FIELD_NAME = "rate_limit";
     public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute";
     private final long requestsPerTimeUnit;
     private final TimeUnit timeUnit;
-    public static RateLimitSettings of(Map<String, Object> map, RateLimitSettings defaultValue, ValidationException validationException) {
+    public static RateLimitSettings of(
+        Map<String, Object> map,
+        RateLimitSettings defaultValue,
+        ValidationException validationException,
+        String serviceName,
+        ConfigurationParseContext context
+    ) {
         Map<String, Object> settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME);
         var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException);
+        if (ConfigurationParseContext.isRequestContext(context)) {
+            throwIfNotEmptyMap(settings, serviceName);
+        }
         return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute);
     }

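The parsing rule in `RateLimitSettings.of` above reads `rate_limit.requests_per_minute` when present and otherwise falls back to the service default. A minimal self-contained sketch of that rule, assuming simplified types (this is not the Elasticsearch `RateLimitSettings` API; the real method also validates leftover keys against the parse context):

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the fallback rule: prefer a user-supplied
// rate_limit.requests_per_minute, else use the service's default.
public class RateLimitSketch {
    static long of(Map<String, Object> serviceSettings, long defaultRequestsPerMinute) {
        @SuppressWarnings("unchecked")
        Map<String, Object> rateLimit =
            (Map<String, Object>) serviceSettings.getOrDefault("rate_limit", new HashMap<>());
        Object rpm = rateLimit.get("requests_per_minute");
        if (rpm == null) {
            return defaultRequestsPerMinute; // no override supplied
        }
        long value = ((Number) rpm).longValue();
        if (value <= 0) {
            throw new IllegalArgumentException("requests_per_minute must be a positive integer");
        }
        return value;
    }

    public static void main(String[] args) {
        Map<String, Object> settings = new HashMap<>();
        settings.put("rate_limit", Map.of("requests_per_minute", 120));
        System.out.println(of(settings, 3000));        // user override wins: 120
        System.out.println(of(new HashMap<>(), 3000)); // default applies: 3000
    }
}
```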
View file

@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -92,7 +93,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
 TruncatorTests.createTruncator()
 );
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
@@ -141,7 +142,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
 TruncatorTests.createTruncator()
 );
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));

View file

@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
@@ -82,7 +83,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -132,7 +133,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -183,7 +184,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 // timeout as zero for no retries
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -237,7 +238,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 // note - there is no complete documentation on Azure's error messages
@@ -313,7 +314,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 // note - there is no complete documentation on Azure's error messages
@@ -389,7 +390,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testExecute_TruncatesInputBeforeSending() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -440,7 +441,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -498,7 +499,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -554,7 +555,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 // timeout as zero for no retries
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 // "choices" missing

View file

@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
 import static org.hamcrest.Matchers.hasSize;
@@ -77,7 +78,7 @@ public class AzureOpenAiCompletionActionTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
@@ -81,7 +82,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
 mockClusterServiceEmpty()
 );
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -73,7 +74,7 @@ public class CohereActionCreatorTests extends ESTestCase {
 public void testCreate_CohereEmbeddingsModel() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -154,7 +155,7 @@ public class CohereActionCreatorTests extends ESTestCase {
 public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -214,7 +215,7 @@ public class CohereActionCreatorTests extends ESTestCase {
 public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@@ -77,7 +77,7 @@ public class CohereCompletionActionTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -138,7 +138,7 @@ public class CohereCompletionActionTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -290,7 +290,7 @@ public class CohereCompletionActionTests extends ESTestCase {
 public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@@ -81,7 +81,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -162,7 +162,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@@ -74,7 +74,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -206,7 +206,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
 public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@@ -79,7 +79,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
 var input = "input";
 var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = senderFactory.createSender()) {
 sender.start();
 String responseJson = """

View file

@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.Matchers.contains;
@@ -75,7 +76,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -131,7 +132,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 );
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -187,7 +188,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """
@@ -239,7 +240,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 );
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]]
@@ -292,7 +293,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJsonContentTooLarge = """
@@ -357,7 +358,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 public void testExecute_TruncatesInputBeforeSending() throws IOException {
 var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-try (var sender = senderFactory.createSender("test_service")) {
+try (var sender = createSender(senderFactory)) {
 sender.start();
 String responseJson = """

View file

@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@ -74,7 +75,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiEmbeddingsModel() throws IOException { public void testCreate_OpenAiEmbeddingsModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@ -127,7 +128,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@ -179,7 +180,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException { public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@ -238,7 +239,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
); );
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@ -292,7 +293,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiChatCompletionModel() throws IOException { public void testCreate_OpenAiChatCompletionModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@@ -355,7 +356,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException { public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@@ -417,7 +418,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException { public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@@ -486,7 +487,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
); );
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@@ -552,7 +553,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
var contentTooLargeErrorMessage = var contentTooLargeErrorMessage =
@@ -635,7 +636,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
var contentTooLargeErrorMessage = var contentTooLargeErrorMessage =
@@ -718,7 +719,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
public void testExecute_TruncatesInputBeforeSending() throws IOException { public void testExecute_TruncatesInputBeforeSending() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """


@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -80,7 +81,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
public void testExecute_ReturnsSuccessfulResponse() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@@ -234,7 +235,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """


@@ -79,7 +79,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
mockClusterServiceEmpty() mockClusterServiceEmpty()
); );
try (var sender = senderFactory.createSender("test_service")) { try (var sender = senderFactory.createSender()) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """


@@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.mock;
public class BaseRequestManagerTests extends ESTestCase {
public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() {
int val1 = 1;
int val2 = 1;
var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
assertThat(manager1.rateLimitGrouping(), is(manager2.rateLimitGrouping()));
}
public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() {
int val1 = 1;
var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
}
public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() {
int val1 = 1;
var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) {
@Override
public void execute(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
}
};
assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
}
}
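The three tests above pin down the grouping contract for `rateLimitGrouping()`: two managers that share the same grouping value and identical `RateLimitSettings` (both the requests-per-unit count and the `TimeUnit`) must fall into the same rate limit group, while changing either settings component must produce a distinct group. A minimal standalone sketch of that contract, using hypothetical stand-in records rather than the actual Elasticsearch classes:

```java
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class GroupingSketch {

    // Hypothetical stand-ins for the classes under test. The group key is
    // derived from the endpoint grouping value plus the full rate limit
    // settings, so changing either component yields a different group.
    record RateLimitSettings(long requestsPerTimeUnit, TimeUnit unit) {}

    record RateLimitGroup(Object grouping, RateLimitSettings settings) {
        int key() {
            // Records generate component-based equals/hashCode, so equal
            // settings contribute equal hash values here.
            return Objects.hash(grouping, settings);
        }
    }

    public static void main(String[] args) {
        var a = new RateLimitGroup(1, new RateLimitSettings(1, TimeUnit.MINUTES));
        var b = new RateLimitGroup(1, new RateLimitSettings(1, TimeUnit.MINUTES));
        var c = new RateLimitGroup(1, new RateLimitSettings(1, TimeUnit.DAYS));
        System.out.println(a.key() == b.key()); // same grouping and settings: same group
        System.out.println(a.key() == c.key()); // different time unit: different group
    }
}
```

This mirrors the assertions above: equal settings give `is(...)` equality of the grouping, and a change to either the limit value or its time unit gives `not(...)`.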


@@ -79,7 +79,7 @@ public class HttpRequestSenderTests extends ESTestCase {
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = createSenderFactory(clientManager, threadRef); var senderFactory = createSenderFactory(clientManager, threadRef);
try (var sender = senderFactory.createSender("test_service")) { try (var sender = createSender(senderFactory)) {
sender.start(); sender.start();
String responseJson = """ String responseJson = """
@@ -135,11 +135,11 @@ public class HttpRequestSenderTests extends ESTestCase {
mockClusterServiceEmpty() mockClusterServiceEmpty()
); );
try (var sender = senderFactory.createSender("test_service")) { try (var sender = senderFactory.createSender()) {
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var thrownException = expectThrows( var thrownException = expectThrows(
AssertionError.class, AssertionError.class,
() -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) () -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
); );
assertThat(thrownException.getMessage(), is("call start() before sending a request")); assertThat(thrownException.getMessage(), is("call start() before sending a request"));
} }
@@ -155,17 +155,12 @@ public class HttpRequestSenderTests extends ESTestCase {
mockClusterServiceEmpty() mockClusterServiceEmpty()
); );
try (var sender = senderFactory.createSender("test_service")) { try (var sender = senderFactory.createSender()) {
assertThat(sender, instanceOf(HttpRequestSender.class)); assertThat(sender, instanceOf(HttpRequestSender.class));
sender.start(); sender.start();
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.send( sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
ExecutableRequestCreatorTests.createMock(),
new DocumentsOnlyInput(List.of()),
TimeValue.timeValueNanos(1),
listener
);
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
@@ -186,16 +181,11 @@ public class HttpRequestSenderTests extends ESTestCase {
mockClusterServiceEmpty() mockClusterServiceEmpty()
); );
try (var sender = senderFactory.createSender("test_service")) { try (var sender = senderFactory.createSender()) {
sender.start(); sender.start();
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
sender.send( sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
ExecutableRequestCreatorTests.createMock(),
new DocumentsOnlyInput(List.of()),
TimeValue.timeValueNanos(1),
listener
);
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
@@ -220,6 +210,7 @@ public class HttpRequestSenderTests extends ESTestCase {
when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService); when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService);
when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class)); when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class));
when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
return new HttpRequestSender.Factory( return new HttpRequestSender.Factory(
ServiceComponentsTests.createWithEmptySettings(mockThreadPool), ServiceComponentsTests.createWithEmptySettings(mockThreadPool),
@@ -248,7 +239,7 @@ public class HttpRequestSenderTests extends ESTestCase {
); );
} }
public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) { public static Sender createSender(HttpRequestSender.Factory factory) {
return factory.createSender(serviceName); return factory.createSender();
} }
} }


@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
@@ -18,12 +19,23 @@ public class RequestExecutorServiceSettingsTests {
} }
public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) { public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) {
return createRequestExecutorServiceSettings(queueCapacity, null);
}
public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(
@Nullable Integer queueCapacity,
@Nullable TimeValue staleDuration
) {
var settingsBuilder = Settings.builder(); var settingsBuilder = Settings.builder();
if (queueCapacity != null) { if (queueCapacity != null) {
settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity); settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity);
} }
if (staleDuration != null) {
settingsBuilder.put(RequestExecutorServiceSettings.RATE_LIMIT_GROUP_STALE_DURATION_SETTING.getKey(), staleDuration);
}
return createRequestExecutorServiceSettings(settingsBuilder.build()); return createRequestExecutorServiceSettings(settingsBuilder.build());
} }


@@ -18,13 +18,19 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.RateLimiter;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import java.io.IOException; import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@@ -42,10 +48,13 @@ import static org.elasticsearch.xpack.inference.external.http.sender.RequestExec
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class RequestExecutorServiceTests extends ESTestCase { public class RequestExecutorServiceTests extends ESTestCase {
@@ -70,7 +79,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
public void testQueueSize_IsOne() { public void testQueueSize_IsOne() {
var service = createRequestExecutorServiceWithMocks(); var service = createRequestExecutorServiceWithMocks();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
assertThat(service.queueSize(), is(1)); assertThat(service.queueSize(), is(1));
} }
@@ -92,7 +101,20 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
} }
public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { public void testCallingStartTwice_ThrowsAssertionException() throws InterruptedException {
var latch = new CountDownLatch(1);
var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class));
service.shutdown();
service.start();
latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
assertTrue(service.isTerminated());
var exception = expectThrows(AssertionError.class, service::start);
assertThat(exception.getMessage(), is("start() can only be called once"));
}
public void testIsTerminated_AfterStopFromSeparateThread() {
var waitToShutdown = new CountDownLatch(1); var waitToShutdown = new CountDownLatch(1);
var waitToReturnFromSend = new CountDownLatch(1); var waitToReturnFromSend = new CountDownLatch(1);
@@ -127,41 +149,48 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
} }
public void testSend_AfterShutdown_Throws() { public void testExecute_AfterShutdown_Throws() {
var service = createRequestExecutorServiceWithMocks(); var service = createRequestExecutorServiceWithMocks();
service.shutdown(); service.shutdown();
var requestManager = RequestManagerTests.createMock("id");
var listener = new PlainActionFuture<InferenceServiceResults>(); var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat( assertThat(
thrownException.getMessage(), thrownException.getMessage(),
is("Failed to enqueue task because the http executor service [test_service] has already shutdown") is(
Strings.format(
"Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown",
requestManager.rateLimitGrouping().hashCode()
)
)
); );
assertTrue(thrownException.isExecutorShutdown()); assertTrue(thrownException.isExecutorShutdown());
} }
public void testSend_Throws_WhenQueueIsFull() { public void testExecute_Throws_WhenQueueIsFull() {
var service = new RequestExecutorService( var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class));
"test_service",
threadPool,
null,
createRequestExecutorServiceSettings(1),
new SingleRequestManager(mock(RetryingHttpSender.class))
);
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
var requestManager = RequestManagerTests.createMock("id");
var listener = new PlainActionFuture<InferenceServiceResults>(); var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat( assertThat(
thrownException.getMessage(), thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full") is(
Strings.format(
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
requestManager.rateLimitGrouping().hashCode()
)
)
); );
assertFalse(thrownException.isExecutorShutdown()); assertFalse(thrownException.isExecutorShutdown());
} }
@@ -203,16 +232,11 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertTrue(service.isShutdown()); assertTrue(service.isShutdown());
} }
public void testSend_CallsOnFailure_WhenRequestTimesOut() { public void testExecute_CallsOnFailure_WhenRequestTimesOut() {
var service = createRequestExecutorServiceWithMocks(); var service = createRequestExecutorServiceWithMocks();
var listener = new PlainActionFuture<InferenceServiceResults>(); var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute( service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
ExecutableRequestCreatorTests.createMock(),
new DocumentsOnlyInput(List.of()),
TimeValue.timeValueNanos(1),
listener
);
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
@@ -222,7 +246,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
); );
} }
public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException { public void testExecute_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
var headerKey = "not empty"; var headerKey = "not empty";
var headerValue = "value"; var headerValue = "value";
@@ -270,7 +294,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
} }
}; };
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
Future<?> executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); Future<?> executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service);
@@ -280,11 +304,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
} }
public void testSend_NotifiesTasksOfShutdown() { public void testExecute_NotifiesTasksOfShutdown() {
var service = createRequestExecutorServiceWithMocks(); var service = createRequestExecutorServiceWithMocks();
var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id");
var listener = new PlainActionFuture<InferenceServiceResults>(); var listener = new PlainActionFuture<InferenceServiceResults>();
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
service.shutdown(); service.shutdown();
service.start(); service.start();
@@ -293,47 +318,62 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertThat( assertThat(
thrownException.getMessage(), thrownException.getMessage(),
is("Failed to send request, queue service [test_service] has shutdown prior to executing request") is(
Strings.format(
"Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
requestManager.rateLimitGrouping().hashCode()
)
)
); );
assertTrue(thrownException.isExecutorShutdown()); assertTrue(thrownException.isExecutorShutdown());
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
} }
public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class); BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
var requestSender = mock(RetryingHttpSender.class);
var service = new RequestExecutorService( var service = new RequestExecutorService(
getTestName(),
threadPool, threadPool,
mockQueueCreator(queue), mockQueueCreator(queue),
null, null,
createRequestExecutorServiceSettingsEmpty(), createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(mock(RetryingHttpSender.class)) requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
); );
when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var requestManager = RequestManagerTests.createMock(requestSender, "id");
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
service.shutdown(); service.shutdown();
return null; return null;
}); });
service.start(); service.start();
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
verify(queue, times(2)).take();
} }
public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception { public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class); BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
when(queue.take()).thenThrow(new InterruptedException("failed")); var sleeper = mock(RequestExecutorService.Sleeper.class);
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
var service = new RequestExecutorService( var service = new RequestExecutorService(
getTestName(),
threadPool, threadPool,
mockQueueCreator(queue), mockQueueCreator(queue),
null, null,
createRequestExecutorServiceSettingsEmpty(), createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(mock(RetryingHttpSender.class)) mock(RetryingHttpSender.class),
Clock.systemUTC(),
sleeper,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
); );
Future<?> executorTermination = threadPool.generic().submit(() -> { Future<?> executorTermination = threadPool.generic().submit(() -> {
@@ -347,66 +387,30 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
verify(queue, times(1)).take();
}
public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception {
var mockTask = mock(RejectableTask.class);
@SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
var service = new RequestExecutorService(
"test_service",
threadPool,
mockQueueCreator(queue),
null,
createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(mock(RetryingHttpSender.class))
);
doAnswer(invocation -> {
service.shutdown();
return mockTask;
}).doReturn(new NoopTask()).when(queue).take();
service.start();
assertTrue(service.isTerminated());
verify(queue, times(1)).take();
ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
verify(mockTask, times(1)).onRejection(argument.capture());
assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class));
assertThat(
argument.getValue().getMessage(),
is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
);
var rejectionException = (EsRejectedExecutionException) argument.getValue();
assertTrue(rejectionException.isExecutorShutdown());
} }
public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException { public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException {
var requestSender = mock(RetryingHttpSender.class); var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1); var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); var service = new RequestExecutorService(threadPool, null, settings, requestSender);
service.execute( service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
ExecutableRequestCreatorTests.createMock(requestSender),
new DocumentsOnlyInput(List.of()),
null,
new PlainActionFuture<>()
);
assertThat(service.queueSize(), is(1)); assertThat(service.queueSize(), is(1));
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); var requestManager = RequestManagerTests.createMock(requestSender, "id");
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat( assertThat(
thrownException.getMessage(), thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full") is(
Strings.format(
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
requestManager.rateLimitGrouping().hashCode()
)
)
); );
settings.setQueueCapacity(2); settings.setQueueCapacity(2);
@@ -426,7 +430,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
assertThat(service.remainingQueueCapacity(), is(2)); assertThat(service.remainingQueueCapacity(requestManager), is(2));
} }
public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException, public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException,
@@ -434,23 +438,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
var requestSender = mock(RetryingHttpSender.class); var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(3); var settings = createRequestExecutorServiceSettings(3);
var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); var service = new RequestExecutorService(threadPool, null, settings, requestSender);
service.execute( service.execute(
ExecutableRequestCreatorTests.createMock(requestSender), RequestManagerTests.createMock(requestSender, "id"),
new DocumentsOnlyInput(List.of()), new DocumentsOnlyInput(List.of()),
null, null,
new PlainActionFuture<>() new PlainActionFuture<>()
); );
service.execute( service.execute(
ExecutableRequestCreatorTests.createMock(requestSender), RequestManagerTests.createMock(requestSender, "id"),
new DocumentsOnlyInput(List.of()), new DocumentsOnlyInput(List.of()),
null, null,
new PlainActionFuture<>() new PlainActionFuture<>()
); );
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); var requestManager = RequestManagerTests.createMock(requestSender, "id");
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
assertThat(service.queueSize(), is(3)); assertThat(service.queueSize(), is(3));
settings.setQueueCapacity(1); settings.setQueueCapacity(1);
@@ -470,7 +475,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
assertThat(service.remainingQueueCapacity(), is(1)); assertThat(service.remainingQueueCapacity(requestManager), is(1));
assertThat(service.queueSize(), is(0)); assertThat(service.queueSize(), is(0));
var thrownException = expectThrows( var thrownException = expectThrows(
@@ -479,7 +484,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
); );
assertThat( assertThat(
thrownException.getMessage(), thrownException.getMessage(),
is("Failed to send request, queue service [test_service] has shutdown prior to executing request") is(
Strings.format(
"Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
requestManager.rateLimitGrouping().hashCode()
)
)
); );
assertTrue(thrownException.isExecutorShutdown()); assertTrue(thrownException.isExecutorShutdown());
} }
@@ -489,23 +499,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
var requestSender = mock(RetryingHttpSender.class); var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1); var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); var service = new RequestExecutorService(threadPool, null, settings, requestSender);
var requestManager = RequestManagerTests.createMock(requestSender);
service.execute( service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
ExecutableRequestCreatorTests.createMock(requestSender),
new DocumentsOnlyInput(List.of()),
null,
new PlainActionFuture<>()
);
assertThat(service.queueSize(), is(1)); assertThat(service.queueSize(), is(1));
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); service.execute(RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, listener);
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat( assertThat(
thrownException.getMessage(), thrownException.getMessage(),
is("Failed to execute task because the http executor service [test_service] queue is full") is(
Strings.format(
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
requestManager.rateLimitGrouping().hashCode()
)
)
); );
settings.setQueueCapacity(0); settings.setQueueCapacity(0);
@@ -525,7 +536,133 @@ public class RequestExecutorServiceTests extends ESTestCase {
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
assertTrue(service.isTerminated()); assertTrue(service.isTerminated());
assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE)); assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE));
}
public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
var mockRateLimiter = mock(RateLimiter.class);
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService(
threadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
doAnswer(invocation -> {
service.shutdown();
return TimeValue.timeValueDays(1);
}).when(mockRateLimiter).timeToReserve(anyInt());
service.start();
verifyNoInteractions(requestSender);
}
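The tests above drive the executor through `RateLimiter.timeToReserve(n)` (how long until `n` tokens are available, without consuming them) and `reserve(n)` (consume them, returning how long the caller must sleep). A minimal token-bucket sketch of those two calls — hypothetical, not the Elasticsearch `RateLimiter` itself:

```java
// Token bucket refilled continuously at a fixed rate. timeToReserve() is a
// non-consuming query; reserve() debits the bucket and reports the wait.
public class TokenBucket {
    private final double tokensPerMillis;
    private double available;
    private long lastRefillMillis;

    public TokenBucket(double tokensPerMinute, double initialTokens, long nowMillis) {
        this.tokensPerMillis = tokensPerMinute / 60_000.0;
        this.available = initialTokens;
        this.lastRefillMillis = nowMillis;
    }

    private void refill(long nowMillis) {
        // Accrue tokens for the elapsed time since the last refill.
        available += (nowMillis - lastRefillMillis) * tokensPerMillis;
        lastRefillMillis = nowMillis;
    }

    /** Millis to wait before {@code tokens} could be reserved; 0 if available now. */
    public long timeToReserve(int tokens, long nowMillis) {
        refill(nowMillis);
        double deficit = tokens - available;
        return deficit <= 0 ? 0 : (long) Math.ceil(deficit / tokensPerMillis);
    }

    /** Consume {@code tokens} (going negative if necessary) and return the wait. */
    public long reserve(int tokens, long nowMillis) {
        long wait = timeToReserve(tokens, nowMillis);
        available -= tokens;
        return wait;
    }
}
```

This split is what lets the worker loop check `timeToReserve` first and either sleep or re-queue instead of blocking inside the limiter, which is also why the first test can shut the service down from the `timeToReserve` stub before anything is sent.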
public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() {
var mockRateLimiter = mock(RateLimiter.class);
when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0));
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(1);
var service = new RequestExecutorService(
threadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0));
doAnswer(invocation -> {
service.shutdown();
return Void.TYPE;
}).when(requestSender).send(any(), any(), any(), any(), any(), any());
service.start();
verify(requestSender, times(1)).send(any(), any(), any(), any(), any(), any());
}
public void testRemovesRateLimitGroup_AfterStaleDuration() {
var now = Instant.now();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(now);
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
var service = new RequestExecutorService(
threadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
clock,
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
assertThat(service.numberOfRateLimitGroups(), is(1));
// the time is moved to after the stale duration, so now we should remove this grouping
when(clock.instant()).thenReturn(now.plus(Duration.ofDays(2)));
service.removeStaleGroupings();
assertThat(service.numberOfRateLimitGroups(), is(0));
var requestManager2 = RequestManagerTests.createMock(requestSender, "id2");
service.execute(requestManager2, new DocumentsOnlyInput(List.of()), null, listener);
assertThat(service.numberOfRateLimitGroups(), is(1));
}
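The test above moves a mocked clock past the stale duration and expects `removeStaleGroupings()` to drop the idle rate-limit group. A sketch of that bookkeeping (hypothetical names; time is passed in explicitly, as the mocked `Clock` in the test suggests):

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Each rate-limit grouping remembers when it last enqueued a request; a
// periodic cleanup removes groupings idle longer than the stale duration.
public class RateLimitGroups {
    private static final class Group {
        volatile Instant lastUsed;
        Group(Instant now) { this.lastUsed = now; }
    }

    private final Map<Object, Group> groups = new ConcurrentHashMap<>();
    private final Duration staleDuration;

    public RateLimitGroups(Duration staleDuration) {
        this.staleDuration = staleDuration;
    }

    public void markUsed(Object grouping, Instant now) {
        groups.computeIfAbsent(grouping, g -> new Group(now)).lastUsed = now;
    }

    public void removeStaleGroupings(Instant now) {
        Instant cutoff = now.minus(staleDuration);
        groups.values().removeIf(group -> group.lastUsed.isBefore(cutoff));
    }

    public int size() {
        return groups.size();
    }
}
```

Without this cleanup, every inference endpoint ever used would keep a queue and rate limiter alive for the life of the node.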
public void testStartsCleanupThread() {
var mockThreadPool = mock(ThreadPool.class);
when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
var requestSender = mock(RetryingHttpSender.class);
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
var service = new RequestExecutorService(
mockThreadPool,
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
null,
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
service.shutdown();
service.start();
ArgumentCaptor<TimeValue> argument = ArgumentCaptor.forClass(TimeValue.class);
verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), argument.capture(), any());
assertThat(argument.getValue(), is(TimeValue.timeValueDays(1)));
} }
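The test above verifies that `start()` schedules the stale-grouping cleanup at a fixed delay equal to the configured stale duration. A small sketch of that wiring using the JDK scheduler (hypothetical `CleanupScheduler`, standing in for Elasticsearch's `ThreadPool.scheduleWithFixedDelay`):

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// On start(), schedule the cleanup task to run repeatedly with a fixed delay
// between the end of one run and the start of the next.
public class CleanupScheduler implements AutoCloseable {
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

    public void start(Runnable removeStaleGroupings, long delayMillis) {
        scheduler.scheduleWithFixedDelay(removeStaleGroupings, delayMillis, delayMillis, TimeUnit.MILLISECONDS);
    }

    @Override
    public void close() {
        scheduler.shutdownNow();
    }
}
```

`scheduleWithFixedDelay` (rather than fixed rate) keeps a slow cleanup pass from stacking overlapping runs, which is presumably why the test captures and asserts the delay argument rather than a rate.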
private Future<?> submitShutdownRequest( private Future<?> submitShutdownRequest(
@@ -552,12 +689,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
} }
private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) { private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) {
return new RequestExecutorService( return new RequestExecutorService(threadPool, startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender);
"test_service",
threadPool,
startupLatch,
createRequestExecutorServiceSettingsEmpty(),
new SingleRequestManager(requestSender)
);
} }
} }
@@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.RequestTests; import org.elasticsearch.xpack.inference.external.request.RequestTests;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyList;
@@ -21,34 +22,47 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class ExecutableRequestCreatorTests { public class RequestManagerTests {
public static RequestManager createMock() { public static RequestManager createMock() {
var mockCreator = mock(RequestManager.class); return createMock(mock(RequestSender.class));
when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {}); }
return mockCreator; public static RequestManager createMock(String inferenceEntityId) {
return createMock(mock(RequestSender.class), inferenceEntityId);
} }
public static RequestManager createMock(RequestSender requestSender) { public static RequestManager createMock(RequestSender requestSender) {
return createMock(requestSender, "id"); return createMock(requestSender, "id", new RateLimitSettings(1));
} }
public static RequestManager createMock(RequestSender requestSender, String modelId) { public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) {
var mockCreator = mock(RequestManager.class); return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1));
}
public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) {
var mockManager = mock(RequestManager.class);
doAnswer(invocation -> { doAnswer(invocation -> {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[5]; ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[4];
return (Runnable) () -> requestSender.send( requestSender.send(
mock(Logger.class), mock(Logger.class),
RequestTests.mockRequest(modelId), RequestTests.mockRequest(inferenceEntityId),
HttpClientContext.create(), HttpClientContext.create(),
() -> false, () -> false,
mock(ResponseHandler.class), mock(ResponseHandler.class),
listener listener
); );
}).when(mockCreator).create(any(), anyList(), any(), any(), any(), any());
return mockCreator; return Void.TYPE;
}).when(mockManager).execute(any(), anyList(), any(), any(), any());
// just return something consistent so the hashing works
when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId);
when(mockManager.rateLimitSettings()).thenReturn(settings);
when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId);
return mockManager;
} }
} }
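The mock above stubs `rateLimitGrouping()`, whose `hashCode()` is what identifies the per-endpoint queue in the rejection messages asserted earlier ("request service [%s] queue is full"). A sketch (hypothetical names) of keying bounded queues off a grouping object:

```java
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;

// One bounded queue per rate-limit grouping instead of one per service; the
// rejection message names the grouping by its hashCode().
public class GroupedQueues {
    private final Map<Object, BlockingQueue<Runnable>> queues = new ConcurrentHashMap<>();
    private final int capacity;

    public GroupedQueues(int capacity) {
        this.capacity = capacity;
    }

    public void execute(Object grouping, String inferenceEntityId, Runnable task) {
        BlockingQueue<Runnable> queue = queues.computeIfAbsent(grouping, g -> new ArrayBlockingQueue<>(capacity));
        if (queue.offer(task) == false) {
            throw new IllegalStateException(
                String.format(
                    "Failed to execute task for inference id [%s] because the request service [%s] queue is full",
                    inferenceEntityId,
                    grouping.hashCode()
                )
            );
        }
    }

    public int queueSize(Object grouping) {
        BlockingQueue<Runnable> q = queues.get(grouping);
        return q == null ? 0 : q.size();
    }
}
```

Grouping by an opaque key (here any `Object`) is what lets several endpoints that share credentials and rate limits also share one queue and one limiter.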
@@ -1,27 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
public class SingleRequestManagerTests extends ESTestCase {
public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() {
var requestCreator = mock(RequestManager.class);
var request = mock(InferenceRequest.class);
when(request.getRequestCreator()).thenReturn(requestCreator);
new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), HttpClientContext.create());
verifyNoInteractions(requestCreator);
}
}
@@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -59,7 +58,7 @@ public class SenderServiceTests extends ESTestCase {
var sender = mock(Sender.class); var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class); var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender); when(factory.createSender()).thenReturn(sender);
try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>(); PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
@@ -67,7 +66,7 @@
listener.actionGet(TIMEOUT); listener.actionGet(TIMEOUT);
verify(sender, times(1)).start(); verify(sender, times(1)).start();
verify(factory, times(1)).createSender(anyString()); verify(factory, times(1)).createSender();
} }
verify(sender, times(1)).close(); verify(sender, times(1)).close();
@@ -79,7 +78,7 @@
var sender = mock(Sender.class); var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class); var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender); when(factory.createSender()).thenReturn(sender);
try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>(); PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
@@ -89,7 +88,7 @@
service.start(mock(Model.class), listener); service.start(mock(Model.class), listener);
listener.actionGet(TIMEOUT); listener.actionGet(TIMEOUT);
verify(factory, times(1)).createSender(anyString()); verify(factory, times(1)).createSender();
verify(sender, times(2)).start(); verify(sender, times(2)).start();
} }
@@ -76,7 +76,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -819,7 +818,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
var sender = mock(Sender.class); var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class); var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender); when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name"); var mockModel = getInvalidModel("model_id", "service_name");
@@ -841,7 +840,7 @@
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
); );
verify(factory, times(1)).createSender(anyString()); verify(factory, times(1)).createSender();
verify(sender, times(1)).start(); verify(sender, times(1)).start();
} }
@@ -112,7 +112,8 @@ public class AzureAiStudioChatCompletionServiceSettingsTests extends ESTestCase
String xContentResult = Strings.toString(builder); String xContentResult = Strings.toString(builder);
assertThat(xContentResult, CoreMatchers.is(""" assertThat(xContentResult, CoreMatchers.is("""
{"target":"target_value","provider":"openai","endpoint_type":"token"}""")); {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
"rate_limit":{"requests_per_minute":3}}"""));
} }
public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) { public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {
@@ -295,7 +295,7 @@ public class AzureAiStudioEmbeddingsServiceSettingsTests extends ESTestCase {
assertThat(xContentResult, CoreMatchers.is(""" assertThat(xContentResult, CoreMatchers.is("""
{"target":"target_value","provider":"openai","endpoint_type":"token",""" + """ {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
"dimensions":1024,"max_input_tokens":512}""")); "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}"""));
} }
public static HashMap<String, Object> createRequestSettingsMap( public static HashMap<String, Object> createRequestSettingsMap(
@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -594,7 +593,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
var sender = mock(Sender.class); var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class); var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender); when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name"); var mockModel = getInvalidModel("model_id", "service_name");
@@ -616,7 +615,7 @@
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
); );
verify(factory, times(1)).createSender(anyString()); verify(factory, times(1)).createSender();
verify(sender, times(1)).start(); verify(sender, times(1)).start();
} }
@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
import java.io.IOException; import java.io.IOException;
@@ -46,7 +47,8 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria
AzureOpenAiServiceFields.API_VERSION, AzureOpenAiServiceFields.API_VERSION,
apiVersion apiVersion
) )
) ),
ConfigurationParseContext.PERSISTENT
); );
assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null)));
@@ -63,18 +65,6 @@
{"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}""")); {"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}"""));
} }
public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = entity.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}"""));
}
@Override @Override
protected Writeable.Reader<AzureOpenAiCompletionServiceSettings> instanceReader() { protected Writeable.Reader<AzureOpenAiCompletionServiceSettings> instanceReader() {
return AzureOpenAiCompletionServiceSettings::new; return AzureOpenAiCompletionServiceSettings::new;
@@ -389,7 +389,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria
"dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}""")); "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}"""));
} }
public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_RateLimit() throws IOException { public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser() throws IOException {
var entity = new AzureOpenAiEmbeddingsServiceSettings( var entity = new AzureOpenAiEmbeddingsServiceSettings(
"resource", "resource",
"deployment", "deployment",
@@ -408,7 +408,7 @@
assertThat(xContentResult, is(""" assertThat(xContentResult, is("""
{"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """ {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
"dimensions":1024,"max_input_tokens":512}""")); "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":1}}"""));
} }
@Override @Override
@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -613,7 +612,7 @@ public class CohereServiceTests extends ESTestCase {
var sender = mock(Sender.class); var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class); var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender(anyString())).thenReturn(sender); when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name"); var mockModel = getInvalidModel("model_id", "service_name");
@@ -635,7 +634,7 @@
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
); );
verify(factory, times(1)).createSender(anyString()); verify(factory, times(1)).createSender();
verify(sender, times(1)).start(); verify(sender, times(1)).start();
} }
@@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.HashMap; import java.util.HashMap;
@@ -28,7 +29,8 @@ public class CohereCompletionModelTests extends ESTestCase {
"service", "service",
new HashMap<>(Map.of()), new HashMap<>(Map.of()),
new HashMap<>(Map.of("model", "overridden model")), new HashMap<>(Map.of("model", "overridden model")),
null null,
ConfigurationParseContext.PERSISTENT
); );
assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -34,7 +35,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
var model = "model"; var model = "model";
var serviceSettings = CohereCompletionServiceSettings.fromMap( var serviceSettings = CohereCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)) new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)),
ConfigurationParseContext.PERSISTENT
); );
assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null))); assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null)));
@ -55,7 +57,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
RateLimitSettings.FIELD_NAME, RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)) new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))
) )
) ),
ConfigurationParseContext.PERSISTENT
); );
assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute)))); assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute))));
@ -72,18 +75,6 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
{"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}""")); {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}"""));
} }
public void testToXContent_WithFilteredObject_WritesAllValues_Except_RateLimit() throws IOException {
var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
var filteredXContent = serviceSettings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"url":"url","model_id":"model"}"""));
}
@Override @Override
protected Writeable.Reader<CohereCompletionServiceSettings> instanceReader() { protected Writeable.Reader<CohereCompletionServiceSettings> instanceReader() {
return CohereCompletionServiceSettings::new; return CohereCompletionServiceSettings::new;


@@ -331,21 +331,6 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereEmbeddingsServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)),
-            CohereEmbeddingType.INT8
-        );
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """
-            "embedding_type":"byte"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereEmbeddingsServiceSettings> instanceReader() {
         return CohereEmbeddingsServiceSettings::new;


@@ -51,20 +51,6 @@ public class CohereRerankServiceSettingsTests extends AbstractWireSerializingTes
             "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereRerankServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3))
-        );
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        // TODO we probably shouldn't allow configuring these fields for reranking
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereRerankServiceSettings> instanceReader() {
         return CohereRerankServiceSettings::new;


@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.endsWith;
 import static org.hamcrest.Matchers.hasSize;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -494,7 +493,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
         var mockModel = getInvalidModel("model_id", "service_name");
@@ -516,7 +515,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                 is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
             );
-            verify(factory, times(1)).createSender(anyString());
+            verify(factory, times(1)).createSender();
             verify(sender, times(1)).start();
         }


@@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.net.URISyntaxException;
@@ -28,7 +29,8 @@ public class GoogleAiStudioCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of("model_id", "model")),
             new HashMap<>(Map.of()),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );
         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));


@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -31,7 +32,10 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
     public void testFromMap_Request_CreatesSettingsCorrectly() {
         var model = "some model";
-        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)));
+        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
+        );
         assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null)));
     }
@@ -47,18 +51,6 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
             {"model_id":"model","rate_limit":{"requests_per_minute":360}}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioCompletionServiceSettings("model", null);
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader<GoogleAiStudioCompletionServiceSettings> instanceReader() {
         return GoogleAiStudioCompletionServiceSettings::new;


@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -55,7 +56,8 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
                     ServiceFields.SIMILARITY,
                     similarity.toString()
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
         assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null)));
@@ -80,23 +82,6 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
             }"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null);
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
-            {
-                "model_id":"model",
-                "max_input_tokens": 1024,
-                "dimensions": 8,
-                "similarity": "dot_product"
-            }"""));
-    }
-
     @Override
     protected Writeable.Reader<GoogleAiStudioEmbeddingsServiceSettings> instanceReader() {
         return GoogleAiStudioEmbeddingsServiceSettings::new;


@@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.junit.After;
 import org.junit.Before;
@@ -33,7 +34,6 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.CoreMatchers.is;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +59,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
         var mockModel = getInvalidModel("model_id", "service_name");
@@ -81,7 +81,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );
-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }
@@ -111,7 +111,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             TaskType taskType,
             Map<String, Object> serviceSettings,
             Map<String, Object> secretSettings,
-            String failureMessage
+            String failureMessage,
+            ConfigurationParseContext context
         ) {
             return null;
         }


@@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -57,7 +58,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var dims = 384;
         var maxInputTokens = 128;
         {
-            var serviceSettings = HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)));
+            var serviceSettings = HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            );
             assertThat(serviceSettings, is(new HuggingFaceServiceSettings(url)));
         }
         {
@@ -73,7 +77,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                         ServiceFields.MAX_INPUT_TOKENS,
                         maxInputTokens
                     )
-                )
+                ),
+                ConfigurationParseContext.PERSISTENT
             );
             assertThat(
                 serviceSettings,
@@ -95,7 +100,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                         RateLimitSettings.FIELD_NAME,
                         new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
                     )
-                )
+                ),
+                ConfigurationParseContext.PERSISTENT
             );
             assertThat(
                 serviceSettings,
@@ -105,7 +111,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     }

     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );
         assertThat(
             thrownException.getMessage(),
@@ -118,7 +127,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT)
         );
         assertThat(
@@ -136,7 +145,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT)
         );
         assertThat(
@@ -152,7 +161,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var similarity = "by_size";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)))
+            () -> HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)),
+                ConfigurationParseContext.PERSISTENT
+            )
         );
         assertThat(
@@ -175,18 +187,6 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
             {"url":"url","rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceServiceSettings(ServiceUtils.createUri("url"), null, null, null, new RateLimitSettings(3));
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"url":"url"}"""));
-    }
-
     @Override
     protected Writeable.Reader<HuggingFaceServiceSettings> instanceReader() {
         return HuggingFaceServiceSettings::new;


@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -32,7 +33,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     public void testFromMap() {
         var url = "https://www.abc.com";
-        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)));
+        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(
+            new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+            ConfigurationParseContext.PERSISTENT
+        );
         assertThat(new HuggingFaceElserServiceSettings(url), is(serviceSettings));
     }
@@ -40,7 +44,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")),
+                ConfigurationParseContext.PERSISTENT
+            )
         );
         assertThat(
@@ -55,7 +62,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     }

     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );
         assertThat(
             thrownException.getMessage(),
@@ -72,7 +82,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            )
         );
         assertThat(
@@ -98,18 +111,6 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
             {"url":"url","max_input_tokens":512,"rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceElserServiceSettings(ServiceUtils.createUri("url"), new RateLimitSettings(3));
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"url":"url","max_input_tokens":512}"""));
-    }
-
     @Override
     protected Writeable.Reader<HuggingFaceElserServiceSettings> instanceReader() {
         return HuggingFaceElserServiceSettings::new;


@@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -393,7 +392,7 @@ public class MistralServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
         var mockModel = getInvalidModel("model_id", "service_name");
@@ -415,7 +414,7 @@ public class MistralServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );
-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }


@@ -98,18 +98,6 @@ public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
             "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToFilteredXContent_WritesFilteredValues() throws IOException {
-        var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, CoreMatchers.is("""
-            {"model":"model_name","dimensions":1024,"max_input_tokens":512}"""));
-    }
-
     public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
         var outputBuffer = new BytesStreamOutput();
         var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));


@@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -675,7 +674,7 @@ public class OpenAiServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
         var mockModel = getInvalidModel("model_id", "service_name");
@@ -697,7 +696,7 @@ public class OpenAiServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );
-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }


@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
@@ -48,7 +49,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
         assertThat(
@@ -77,7 +79,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
         assertThat(
@@ -101,7 +104,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
         assertNull(serviceSettings.uri());
@@ -113,7 +117,10 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> OpenAiChatCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")))
+            () -> OpenAiChatCompletionServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
+            )
         );
         assertThat(
@@ -132,7 +139,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var maxInputTokens = 8192;
         var serviceSettings = OpenAiChatCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens))
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)),
+            ConfigurationParseContext.PERSISTENT
         );
         assertNull(serviceSettings.uri());
@@ -144,7 +152,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );
@@ -164,7 +173,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );
@@ -213,19 +223,6 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
             {"model_id":"model","rate_limit":{"requests_per_minute":500}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new OpenAiChatCompletionServiceSettings("model", "url", "org", 1024, new RateLimitSettings(2));
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"model_id":"model","url":"url","organization_id":"org",""" + """
-            "max_input_tokens":1024}"""));
-    }
-
     @Override
     protected Writeable.Reader<OpenAiChatCompletionServiceSettings> instanceReader() {
         return OpenAiChatCompletionServiceSettings::new;