mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-25 07:37:19 -04:00
[ML] Inference API rate limit queuing logic refactor (#107706)
* Adding new executor
* Adding in queuing logic
* working tests
* Added cleanup task
* Update docs/changelog/107706.yaml
* Updating yml
* deregistering callbacks for settings changes
* Cleaning up code
* Update docs/changelog/107706.yaml
* Fixing rate limit settings bug and only sleeping least amount
* Removing debug logging
* Removing commented code
* Renaming feedback
* fixing tests
* Updating docs and validation
* Fixing source blocks
* Adjusting cancel logic
* Reformatting ascii
* Addressing feedback
* adding rate limiting for google embeddings and mistral

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
parent cd84749d87
commit fdb5058b13
102 changed files with 1499 additions and 937 deletions
docs/changelog/107706.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
pr: 107706
summary: Add rate limiting support for the Inference API
area: Machine Learning
type: enhancement
issues: []
@@ -7,21 +7,17 @@ experimental[]
Creates an {infer} endpoint to perform an {infer} task.

IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in
{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face.
For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models.
However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the
<<ml-df-trained-models-apis>>.


[discrete]
[[put-inference-api-request]]
==== {api-request-title}

`PUT /_inference/<task_type>/<inference_id>`


[discrete]
[[put-inference-api-prereqs]]
==== {api-prereq-title}
@@ -29,7 +25,6 @@ use these models or if you want to use non-NLP models, use the
* Requires the `manage_inference` <<privileges-list-cluster,cluster privilege>>
(the built-in `inference_admin` role grants this privilege)


[discrete]
[[put-inference-api-desc]]
==== {api-description-title}
@@ -48,25 +43,23 @@ The following services are available through the {infer} API:
* Hugging Face
* OpenAI


[discrete]
[[put-inference-api-path-params]]
==== {api-path-parms-title}


`<inference_id>`::
(Required, string)
The unique identifier of the {infer} endpoint.

`<task_type>`::
(Required, string)
The type of the {infer} task that the model will perform.
Available task types:
* `completion`,
* `rerank`,
* `sparse_embedding`,
* `text_embedding`.
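Putting the two path parameters together: a minimal sketch of a request that creates a `sparse_embedding` endpoint backed by the `elser` service (the endpoint name `my-elser-endpoint` is a placeholder; a complete ELSER example appears later on this page):

[source,console]
------------------------------------------------------------
PUT _inference/sparse_embedding/my-elser-endpoint
{
    "service": "elser",
    "service_settings": {
        "num_allocations": 1,
        "num_threads": 1
    }
}
------------------------------------------------------------
// TEST[skip:TBD]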


[discrete]
[[put-inference-api-request-body]]
==== {api-request-body-title}
@@ -78,21 +71,18 @@ Available services:

* `azureopenai`: specify the `completion` or `text_embedding` task type to use the Azure OpenAI service.
* `azureaistudio`: specify the `completion` or `text_embedding` task type to use the Azure AI Studio service.
* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service.
* `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland.
* `elser`: specify the `sparse_embedding` task type to use the ELSER service.
* `googleaistudio`: specify the `completion` task to use the Google AI Studio service.
* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service.
* `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service.


`service_settings`::
(Required, object)
Settings used to install the {infer} model.
These settings are specific to the `service` you specified.
+
.`service_settings` for the `azureaistudio` service
@@ -104,11 +94,10 @@ Settings used to install the {infer} model.
A valid API key of your Azure AI Studio model deployment.
This key can be found on the overview page for your deployment in the management section of your https://ai.azure.com/[Azure AI Studio] account.

IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`target`:::
(Required, string)
@@ -142,11 +131,13 @@ For "real-time" endpoints which are billed per hour of usage, specify `realtime`
By default, the `azureaistudio` service sets the number of requests allowed per minute to `240`.
This helps to minimize the number of rate limit errors returned from Azure AI Studio.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
=====
+
.`service_settings` for the `azureopenai` service
@@ -181,6 +172,22 @@ Your Azure OpenAI deployments can be found through the https://oai.azure.com/[Azure
The Azure API version ID to use.
We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings[latest supported non-preview version].

`rate_limit`:::
(Optional, object)
The `azureopenai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `1440`.
For `completion` it is set to `120`.
This helps to minimize the number of rate limit errors returned from Azure.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for Azure can be found in the https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits[Quota limits docs] and https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest[How to change the quotas].
=====
+
.`service_settings` for the `cohere` service
@@ -188,24 +195,24 @@ We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/opena
=====
`api_key`:::
(Required, string)
A valid API key of your Cohere account.
You can find your Cohere API keys or you can create a new one
https://dashboard.cohere.com/api-keys[on the API keys settings page].

IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`embedding_type`::
(Optional, string)
Only for `text_embedding`.
Specifies the types of embeddings you want to get back.
Defaults to `float`.
Valid values are:
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
* `float`: use it for the default float embeddings.
* `int8`: use it for signed int8 embeddings.

`model_id`::
(Optional, string)
@@ -214,50 +221,68 @@ To review the available `rerank` models, refer to the
https://docs.cohere.com/reference/rerank-1[Cohere docs].

To review the available `text_embedding` models, refer to the
https://docs.cohere.com/reference/embed[Cohere docs].
The default value for `text_embedding` is `embed-english-v2.0`.

`rate_limit`:::
(Optional, object)
By default, the `cohere` service sets the number of requests allowed per minute to `10000`.
This value is the same for all task types.
This helps to minimize the number of rate limit errors returned from Cohere.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
+
More information about Cohere's rate limits can be found in https://docs.cohere.com/docs/going-live#production-key-specifications[Cohere's production key docs].

=====
+
.`service_settings` for the `elasticsearch` service
[%collapsible%closed]
=====

`model_id`:::
(Required, string)
The name of the model to use for the {infer} task.
It can be the ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or a text embedding model already
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].

`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.

`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.
=====
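As an illustration of the allocation constraint above (all numbers are placeholders): on a node with 8 available processors, `num_threads: 2` allows at most `8 / 2 = 4` allocations, so a valid fragment could look like:

[source,text]
----
"service_settings": {
    "model_id": ".multilingual-e5-small",
    "num_allocations": 4,
    "num_threads": 2
}
----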
+
.`service_settings` for the `elser` service
[%collapsible%closed]
=====

`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.

`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.
=====
+
.`service_settings` for the `googleaistudio` service
[%collapsible%closed]
=====

`api_key`:::
(Required, string)
A valid API key for the Google Gemini API.
@@ -274,76 +299,113 @@ This helps to minimize the number of rate limit errors returned from Google AI Studio.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
--
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
--

=====
+
.`service_settings` for the `hugging_face` service
[%collapsible%closed]
=====

`api_key`:::
(Required, string)
A valid access token of your Hugging Face account.
You can find your Hugging Face access tokens or you can create a new one
https://huggingface.co/settings/tokens[on the settings page].

IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`url`:::
(Required, string)
The URL endpoint to use for the requests.

`rate_limit`:::
(Optional, object)
By default, the `huggingface` service sets the number of requests allowed per minute to `3000`.
This helps to minimize the number of rate limit errors returned from Hugging Face.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----

=====
+
.`service_settings` for the `openai` service
[%collapsible%closed]
=====

`api_key`:::
(Required, string)
A valid API key of your OpenAI account.
You can find your OpenAI API keys in your OpenAI account under the
https://platform.openai.com/api-keys[API keys section].

IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`model_id`:::
(Required, string)
The name of the model to use for the {infer} task.
Refer to the
https://platform.openai.com/docs/guides/embeddings/what-are-embeddings[OpenAI documentation]
for the list of available text embedding models.

`organization_id`:::
(Optional, string)
The unique identifier of your organization.
You can find the Organization ID in your OpenAI account under
https://platform.openai.com/account/organization[**Settings** > **Organizations**].

`url`:::
(Optional, string)
The URL endpoint to use for the requests.
Can be changed for testing purposes.
Defaults to `https://api.openai.com/v1/embeddings`.

`rate_limit`:::
(Optional, object)
The `openai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `3000`.
For `completion` it is set to `500`.
This helps to minimize the number of rate limit errors returned from OpenAI.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for OpenAI can be found in your https://platform.openai.com/account/limits[Account limits].

=====

`task_settings`::
(Optional, object)
Settings to configure the {infer} task.
These settings are specific to the `<task_type>` you specified.
+
.`task_settings` for the `completion` task type
[%collapsible%closed]
=====

`do_sample`:::
(Optional, float)
For the `azureaistudio` service only.
@@ -358,8 +420,8 @@ Defaults to 64.

`user`:::
(Optional, string)
For `openai` service only.
Specifies the user issuing the request, which can be used for abuse detection.

`temperature`:::
(Optional, float)
@@ -378,45 +440,46 @@ Should not be used if `temperature` is specified.
.`task_settings` for the `rerank` task type
[%collapsible%closed]
=====

`return_documents`::
(Optional, boolean)
For `cohere` service only.
Specify whether to return doc text within the results.

`top_n`::
(Optional, integer)
The number of most relevant documents to return, defaults to the number of the documents.

=====
+
.`task_settings` for the `text_embedding` task type
[%collapsible%closed]
=====

`input_type`:::
(Optional, string)
For `cohere` service only.
Specifies the type of input passed to the model.
Valid values are:
* `classification`: use it for embeddings passed through a text classifier.
* `clustering`: use it for the embeddings run through a clustering algorithm.
* `ingest`: use it for storing document embeddings in a vector database.
* `search`: use it for storing embeddings of search queries run against a vector database to find relevant documents.

`truncate`:::
(Optional, string)
For `cohere` service only.
Specifies how the API handles inputs longer than the maximum token length.
Defaults to `END`.
Valid values are:
* `NONE`: when the input exceeds the maximum input token length an error is returned.
* `START`: when the input exceeds the maximum input token length the start of the input is discarded.
* `END`: when the input exceeds the maximum input token length the end of the input is discarded.

`user`:::
(Optional, string)
For `openai`, `azureopenai` and `azureaistudio` services only.
Specifies the user issuing the request, which can be used for abuse detection.

=====
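For instance, a sketch of `task_settings` for a `text_embedding` task on the `cohere` service, combining the parameters above (the values shown are placeholders):

[source,text]
----
"task_settings": {
    "input_type": "search",
    "truncate": "END"
}
----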
[discrete]
@@ -470,7 +533,6 @@ PUT _inference/completion/azure_ai_studio_completion

The list of chat completion models that you can choose from in your deployment can be found in the https://ai.azure.com/explore/models?selectedTask=chat-completion[Azure AI Studio model explorer].


[discrete]
[[inference-example-azureopenai]]
===== Azure OpenAI service
@@ -519,7 +581,6 @@ The list of chat completion models that you can choose from in your Azure OpenAI
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-models[GPT-4 and GPT-4 Turbo models]
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35[GPT-3.5]


[discrete]
[[inference-example-cohere]]
===== Cohere service
@@ -565,7 +626,6 @@ PUT _inference/rerank/cohere-rerank
For more examples, also review the
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].


[discrete]
[[inference-example-e5]]
===== E5 via the `elasticsearch` service
@@ -586,10 +646,9 @@ PUT _inference/text_embedding/my-e5-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of one of the built-in E5 models.
Valid values are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`.
For further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].


[discrete]
[[inference-example-elser]]
@ -597,8 +656,7 @@ further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].
|
||||||
|
|
||||||
The following example shows how to create an {infer} endpoint called
|
The following example shows how to create an {infer} endpoint called
|
||||||
`my-elser-model` to perform a `sparse_embedding` task type.
|
`my-elser-model` to perform a `sparse_embedding` task type.
|
||||||
Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more
|
Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more info.
|
||||||
info.
|
|
||||||
|
|
||||||
[source,console]
|
[source,console]
|
||||||
------------------------------------------------------------
|
------------------------------------------------------------
|
||||||
|
@ -672,16 +730,17 @@ PUT _inference/text_embedding/hugging-face-embeddings
|
||||||
}
|
}
|
||||||
------------------------------------------------------------
|
------------------------------------------------------------
|
||||||
// TEST[skip:TBD]
|
// TEST[skip:TBD]
|
||||||
<1> A valid Hugging Face access token. You can find on the
|
<1> A valid Hugging Face access token.
|
||||||
|
You can find on the
|
||||||
https://huggingface.co/settings/tokens[settings page of your account].
|
https://huggingface.co/settings/tokens[settings page of your account].
|
||||||
<2> The {infer} endpoint URL you created on Hugging Face.
|
<2> The {infer} endpoint URL you created on Hugging Face.
|
||||||
|
|
||||||
Create a new {infer} endpoint on
|
Create a new {infer} endpoint on
|
||||||
https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an
|
https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an endpoint URL.
|
||||||
endpoint URL. Select the model you want to use on the new endpoint creation page
|
Select the model you want to use on the new endpoint creation page - for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings`
|
||||||
- for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings`
|
task under the Advanced configuration section.
|
||||||
task under the Advanced configuration section. Create the endpoint. Copy the URL
|
Create the endpoint.
|
||||||
after the endpoint initialization has been finished.
|
Copy the URL after the endpoint initialization has been finished.
|
||||||
|
|
||||||
[discrete]
|
[discrete]
|
||||||
[[inference-example-hugging-face-supported-models]]
|
[[inference-example-hugging-face-supported-models]]
|
||||||
|
@ -695,7 +754,6 @@ The list of recommended models for the Hugging Face service:
|
||||||
* https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base]
|
* https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base]
|
||||||
* https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small]
|
* https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small]
|
||||||
|
|
||||||
|
|
||||||
[discrete]
|
[discrete]
|
||||||
[[inference-example-eland]]
|
[[inference-example-eland]]
|
||||||
===== Models uploaded by Eland via the elasticsearch service
|
===== Models uploaded by Eland via the elasticsearch service
|
||||||
|
@ -716,11 +774,9 @@ PUT _inference/text_embedding/my-msmarco-minilm-model
|
||||||
}
|
}
|
||||||
------------------------------------------------------------
|
------------------------------------------------------------
|
||||||
// TEST[skip:TBD]
|
// TEST[skip:TBD]
|
||||||
<1> The `model_id` must be the ID of a text embedding model which has already
|
<1> The `model_id` must be the ID of a text embedding model which has already been
|
||||||
been
|
|
||||||
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].
|
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].
|
||||||
|
|
||||||
|
|
||||||
[discrete]
|
[discrete]
|
||||||
[[inference-example-openai]]
|
[[inference-example-openai]]
|
||||||
===== OpenAI service
|
===== OpenAI service
|
||||||
|
@ -756,4 +812,3 @@ PUT _inference/completion/openai-completion
|
||||||
}
|
}
|
||||||
------------------------------------------------------------
|
------------------------------------------------------------
|
||||||
// TEST[skip:TBD]
|
// TEST[skip:TBD]
|
||||||
|
|
||||||
|
|
|
@@ -88,6 +88,13 @@ public class TimeValue implements Comparable<TimeValue> {
         return new TimeValue(days, TimeUnit.DAYS);
     }
 
+    /**
+     * @return the {@link TimeValue} object that has the least duration.
+     */
+    public static TimeValue min(TimeValue time1, TimeValue time2) {
+        return time1.compareTo(time2) < 0 ? time1 : time2;
+    }
+
     /**
      * @return the unit used for the this time value, see {@link #duration()}
      */
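The new `TimeValue.min` helper above returns whichever duration sorts first under `compareTo`; the changelog note about "only sleeping least amount" suggests the rate limiter uses it to wait for the shorter of two candidate delays. A minimal standalone sketch of the same comparison-based pattern, using `java.time.Duration` as a stand-in for Elasticsearch's `TimeValue` (an assumption, so the snippet runs without the Elasticsearch classpath):

```java
import java.time.Duration;

// Sketch of the same "least duration" helper on java.time.Duration
// (assumption: plain JDK types stand in for Elasticsearch's TimeValue).
final class Durations {
    // Mirrors TimeValue.min: the comparison is strictly "less than",
    // so on a tie the second argument is returned.
    static Duration min(Duration time1, Duration time2) {
        return time1.compareTo(time2) < 0 ? time1 : time2;
    }
}
```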
@@ -17,6 +17,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.object.HasToString.hasToString;
 
@@ -231,6 +232,12 @@ public class TimeValueTests extends ESTestCase {
         assertThat(ex.getMessage(), containsString("duration cannot be negative"));
     }
 
+    public void testMin() {
+        assertThat(TimeValue.min(TimeValue.ZERO, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(0)));
+        assertThat(TimeValue.min(TimeValue.MAX_VALUE, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(1)));
+        assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE));
+    }
+
     private TimeUnit randomTimeUnitObject() {
         return randomFrom(
             TimeUnit.NANOSECONDS,
@@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor {
     private final ServiceComponents serviceComponents;
 
     public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        // TODO Batching - accept a class that can handle batching
         this.sender = Objects.requireNonNull(sender);
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
     }
@@ -36,6 +36,7 @@ public class CohereEmbeddingsAction implements ExecutableAction {
             model.getServiceSettings().getCommonSettings().uri(),
             "Cohere embeddings"
         );
+        // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
         requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool);
     }
 
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -37,17 +36,16 @@ public class AzureAiStudioChatCompletionRequestManager extends AzureAiStudioRequ
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
     private static ResponseHandler createCompletionHandler() {
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -41,17 +40,16 @@ public class AzureAiStudioEmbeddingsRequestManager extends AzureAiStudioRequestM
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
     private static ResponseHandler createEmbeddingsHandler() {
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -43,16 +42,15 @@ public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManag
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         @Nullable String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model);
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
 }
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -55,16 +54,15 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model);
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -38,11 +38,16 @@ abstract class BaseRequestManager implements RequestManager {
 
     @Override
     public Object rateLimitGrouping() {
-        return rateLimitGroup;
+        // It's possible that two inference endpoints have the same information defining the group but have different
+        // rate limits then they should be in different groups otherwise whoever initially created the group will set
+        // the rate and the other inference endpoint's rate will be ignored
+        return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
     }
 
     @Override
     public RateLimitSettings rateLimitSettings() {
         return rateLimitSettings;
     }
 
+    private record EndpointGrouping(Object group, RateLimitSettings settings) {}
 }
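The comment added to `rateLimitGrouping()` above explains why the grouping key must include the rate limit settings: two endpoints whose grouping fields match but whose limits differ must not share one rate limiter. A Java `record` gives that composite key value-based `equals`/`hashCode` for free. A small self-contained sketch of the idea, with a deliberately simplified stand-in for `RateLimitSettings` (the real class carries more than one field):

```java
// Simplified stand-in (assumption): the real RateLimitSettings is richer
// than a single requests-per-minute figure.
record RateLimitSettings(long requestsPerMinute) {}

// Composite grouping key: two keys are equal only when BOTH the grouping
// object and the rate limit settings are equal, so endpoints with the same
// group information but different limits land in different rate limit groups.
record EndpointGrouping(Object group, RateLimitSettings settings) {}
```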
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -46,16 +45,15 @@ public class CohereCompletionRequestManager extends CohereRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         CohereCompletionRequest request = new CohereCompletionRequest(input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -44,16 +43,15 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -44,16 +43,15 @@ public class CohereRerankRequestManager extends CohereRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         CohereRerankRequest request = new CohereRerankRequest(query, input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -23,7 +23,6 @@ record ExecutableInferenceRequest(
     RequestSender requestSender,
     Logger logger,
     Request request,
-    HttpClientContext context,
     ResponseHandler responseHandler,
     Supplier<Boolean> hasFinished,
     ActionListener<InferenceServiceResults> listener
@@ -34,7 +33,7 @@ record ExecutableInferenceRequest(
         var inferenceEntityId = request.createHttpRequest().inferenceEntityId();
 
         try {
-            requestSender.send(logger, request, context, hasFinished, responseHandler, listener);
+            requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
         } catch (Exception e) {
             var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
             logger.warn(errorMessage, e);
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -42,15 +41,14 @@ public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioReques
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model);
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -48,17 +47,16 @@ public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioReques
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -15,6 +15,8 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
 import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
 import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -39,30 +41,28 @@ public class HttpRequestSender implements Sender {
         private final ServiceComponents serviceComponents;
         private final HttpClientManager httpClientManager;
         private final ClusterService clusterService;
-        private final SingleRequestManager requestManager;
+        private final RequestSender requestSender;
 
         public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
             this.serviceComponents = Objects.requireNonNull(serviceComponents);
             this.httpClientManager = Objects.requireNonNull(httpClientManager);
             this.clusterService = Objects.requireNonNull(clusterService);
 
-            var requestSender = new RetryingHttpSender(
+            requestSender = new RetryingHttpSender(
                 this.httpClientManager.getHttpClient(),
                 serviceComponents.throttlerManager(),
                 new RetrySettings(serviceComponents.settings(), clusterService),
                 serviceComponents.threadPool()
             );
-            requestManager = new SingleRequestManager(requestSender);
         }
 
-        public Sender createSender(String serviceName) {
+        public Sender createSender() {
             return new HttpRequestSender(
-                serviceName,
                 serviceComponents.threadPool(),
                 httpClientManager,
                 clusterService,
                 serviceComponents.settings(),
-                requestManager
+                requestSender
             );
         }
     }
@@ -71,26 +71,24 @@ public class HttpRequestSender implements Sender {
 
     private final ThreadPool threadPool;
     private final HttpClientManager manager;
-    private final RequestExecutorService service;
+    private final RequestExecutor service;
     private final AtomicBoolean started = new AtomicBoolean(false);
    private final CountDownLatch startCompleted = new CountDownLatch(1);
 
     private HttpRequestSender(
-        String serviceName,
         ThreadPool threadPool,
         HttpClientManager httpClientManager,
         ClusterService clusterService,
         Settings settings,
-        SingleRequestManager requestManager
+        RequestSender requestSender
     ) {
         this.threadPool = Objects.requireNonNull(threadPool);
         this.manager = Objects.requireNonNull(httpClientManager);
         service = new RequestExecutorService(
-            serviceName,
             threadPool,
             startCompleted,
             new RequestExecutorServiceSettings(settings, clusterService),
-            requestSender
+            requestSender
         );
     }
 
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -55,26 +54,17 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getTokenLimit());
         var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
 
-        return new ExecutableInferenceRequest(
-            requestSender,
-            logger,
-            request,
-            context,
-            responseHandler,
-            hasRequestCompletedFunction,
-            listener
-        );
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
     }
 
     record RateLimitGrouping(int accountHash) {
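The hunk above applies one mechanical change that repeats across every request manager in this commit: instead of building and returning a `Runnable` for the executor to schedule later, the manager now drives the prepared request itself through an `execute(...)` call. A minimal sketch of the before/after shape, using hypothetical simplified interfaces rather than the actual Elasticsearch types:

```java
import java.util.List;
import java.util.function.Consumer;

// Hypothetical, simplified interfaces (not the real Elasticsearch types)
// contrasting the old "build a Runnable" shape with the new "execute directly" shape.
interface OldStyleManager {
    Runnable create(List<String> input, Consumer<String> listener);
}

interface NewStyleManager {
    void execute(List<String> input, Consumer<String> listener);
}

class ExecutePatternSketch {
    public static void main(String[] args) {
        // Old style: the caller receives a Runnable and schedules it later.
        OldStyleManager oldStyle = (input, listener) -> () -> listener.accept("sent " + input.size());
        oldStyle.create(List.of("a", "b"), System.out::println).run();

        // New style: the manager sends the request as part of the call itself.
        NewStyleManager newStyle = (input, listener) -> listener.accept("sent " + input.size());
        newStyle.execute(List.of("a", "b"), System.out::println);
    }
}
```

Both shapes deliver the same result to the listener; the new one removes the intermediate scheduling step so the rate-limiting loop can decide when each queued task runs.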
@@ -19,9 +19,9 @@ import java.util.function.Supplier;
 public interface InferenceRequest {
 
     /**
-     * Returns the creator that handles building an executable request based on the input provided.
+     * Returns the manager that handles building and executing an inference request.
      */
-    RequestManager getRequestCreator();
+    RequestManager getRequestManager();
 
     /**
      * Returns the query associated with this request. Used for Rerank tasks.
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -51,18 +50,17 @@ public class MistralEmbeddingsRequestManager extends BaseRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
     record RateLimitGrouping(int keyHashCode) {
@@ -1,52 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.sender;
-
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.inference.InferenceServiceResults;
-
-import java.util.List;
-import java.util.function.Supplier;
-
-class NoopTask implements RejectableTask {
-
-    @Override
-    public RequestManager getRequestCreator() {
-        return null;
-    }
-
-    @Override
-    public String getQuery() {
-        return null;
-    }
-
-    @Override
-    public List<String> getInput() {
-        return null;
-    }
-
-    @Override
-    public ActionListener<InferenceServiceResults> getListener() {
-        return null;
-    }
-
-    @Override
-    public boolean hasCompleted() {
-        return true;
-    }
-
-    @Override
-    public Supplier<Boolean> getRequestCompletedFunction() {
-        return () -> true;
-    }
-
-    @Override
-    public void onRejection(Exception e) {
-
-    }
-}
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -43,17 +42,16 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         @Nullable String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
     private static ResponseHandler createCompletionHandler() {
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -55,17 +54,16 @@ public class OpenAiEmbeddingsRequestManager extends OpenAiRequestManager {
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -17,21 +16,31 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue;
+import org.elasticsearch.xpack.inference.common.RateLimiter;
 import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
+import java.time.Clock;
+import java.time.Instant;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Consumer;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
 
 /**
  * A service for queuing and executing {@link RequestTask}. This class is useful because the
@@ -45,7 +54,18 @@ import static org.elasticsearch.core.Strings.format;
  * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info.
  */
 class RequestExecutorService implements RequestExecutor {
-    private static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> QUEUE_CREATOR =
+
+    /**
+     * Provides dependency injection mainly for testing
+     */
+    interface Sleeper {
+        void sleep(TimeValue sleepTime) throws InterruptedException;
+    }
+
+    // default for tests
+    static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
+    // default for tests
+    static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
         new AdjustableCapacityBlockingQueue.QueueCreator<>() {
             @Override
             public BlockingQueue<RejectableTask> create(int capacity) {
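The `Sleeper` interface introduced above is a small dependency-injection seam: production wiring performs a real blocking sleep, while tests can substitute a recording stub so the polling loop can be stepped deterministically. A minimal sketch of that pattern, with illustrative names rather than the real Elasticsearch types:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

// Sketch of a Sleeper seam: production sleeps for real, tests record instead.
class SleeperSketch {
    interface Sleeper {
        void sleep(long millis) throws InterruptedException;
    }

    // Production default: actually blocks the calling thread.
    static final Sleeper REAL = TimeUnit.MILLISECONDS::sleep;

    public static void main(String[] args) throws InterruptedException {
        // Test double: records the requested sleep times instead of blocking,
        // so a test can assert how long the loop *wanted* to wait.
        List<Long> recorded = new ArrayList<>();
        Sleeper fake = recorded::add;

        fake.sleep(50);
        fake.sleep(100);
        System.out.println(recorded); // [50, 100]
    }
}
```

The same idea motivates the injected `Clock` and `RateLimiterCreator` in this file: each replaces a hard-wired side effect with a constructor argument a test can control.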
@@ -65,86 +85,116 @@ class RequestExecutorService implements RequestExecutor {
             }
         };
 
+    /**
+     * Provides dependency injection mainly for testing
+     */
+    interface RateLimiterCreator {
+        RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit);
+    }
+
+    // default for testing
+    static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new;
     private static final Logger logger = LogManager.getLogger(RequestExecutorService.class);
-    private final String serviceName;
-    private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
-    private final AtomicBoolean running = new AtomicBoolean(true);
-    private final CountDownLatch terminationLatch = new CountDownLatch(1);
-    private final HttpClientContext httpContext;
+    private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
+
+    private final ConcurrentMap<Object, RateLimitingEndpointHandler> rateLimitGroupings = new ConcurrentHashMap<>();
     private final ThreadPool threadPool;
     private final CountDownLatch startupLatch;
-    private final BlockingQueue<Runnable> controlQueue = new LinkedBlockingQueue<>();
-    private final SingleRequestManager requestManager;
+    private final CountDownLatch terminationLatch = new CountDownLatch(1);
+    private final RequestSender requestSender;
+    private final RequestExecutorServiceSettings settings;
+    private final Clock clock;
+    private final AtomicBoolean shutdown = new AtomicBoolean(false);
+    private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
+    private final Sleeper sleeper;
+    private final RateLimiterCreator rateLimiterCreator;
+    private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
+    private final AtomicBoolean started = new AtomicBoolean(false);
 
     RequestExecutorService(
-        String serviceName,
         ThreadPool threadPool,
         @Nullable CountDownLatch startupLatch,
         RequestExecutorServiceSettings settings,
-        SingleRequestManager requestManager
+        RequestSender requestSender
     ) {
-        this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager);
+        this(
+            threadPool,
+            DEFAULT_QUEUE_CREATOR,
+            startupLatch,
+            settings,
+            requestSender,
+            Clock.systemUTC(),
+            DEFAULT_SLEEPER,
+            DEFAULT_RATE_LIMIT_CREATOR
+        );
     }
 
-    /**
-     * This constructor should only be used directly for testing.
-     */
     RequestExecutorService(
-        String serviceName,
         ThreadPool threadPool,
-        AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
+        AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator,
         @Nullable CountDownLatch startupLatch,
         RequestExecutorServiceSettings settings,
-        SingleRequestManager requestManager
+        RequestSender requestSender,
+        Clock clock,
+        Sleeper sleeper,
+        RateLimiterCreator rateLimiterCreator
     ) {
-        this.serviceName = Objects.requireNonNull(serviceName);
         this.threadPool = Objects.requireNonNull(threadPool);
-        this.httpContext = HttpClientContext.create();
-        this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
+        this.queueCreator = Objects.requireNonNull(queueCreator);
         this.startupLatch = startupLatch;
-        this.requestManager = Objects.requireNonNull(requestManager);
-
-        Objects.requireNonNull(settings);
-        settings.registerQueueCapacityCallback(this::onCapacityChange);
+        this.requestSender = Objects.requireNonNull(requestSender);
+        this.settings = Objects.requireNonNull(settings);
+        this.clock = Objects.requireNonNull(clock);
+        this.sleeper = Objects.requireNonNull(sleeper);
+        this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
     }
 
-    private void onCapacityChange(int capacity) {
-        logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity));
-
-        var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity));
-        if (enqueuedCapacityCommand == false) {
-            logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later.");
-        } else {
-            // ensure that the task execution loop wakes up
-            queue.offer(new NoopTask());
+    public void shutdown() {
+        if (shutdown.compareAndSet(false, true)) {
+            if (cancellableCleanupTask.get() != null) {
+                logger.debug(() -> "Stopping clean up thread");
+                cancellableCleanupTask.get().cancel();
+            }
         }
     }
 
-    private void updateCapacity(int newCapacity) {
-        try {
-            queue.setCapacity(newCapacity);
-        } catch (Exception e) {
-            logger.warn(
-                format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName),
-                e
-            );
-        }
+    public boolean isShutdown() {
+        return shutdown.get();
+    }
+
+    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
+        return terminationLatch.await(timeout, unit);
+    }
+
+    public boolean isTerminated() {
+        return terminationLatch.getCount() == 0;
+    }
+
+    public int queueSize() {
+        return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum();
     }
 
     /**
      * Begin servicing tasks.
+     * <p>
+     * <b>Note: This should only be called once for the life of the object.</b>
+     * </p>
      */
     public void start() {
         try {
+            assert started.get() == false : "start() can only be called once";
+            started.set(true);
+
+            startCleanupTask();
             signalStartInitiated();
 
-            while (running.get()) {
+            while (isShutdown() == false) {
                 handleTasks();
             }
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
         } finally {
-            running.set(false);
+            shutdown();
             notifyRequestsOfShutdown();
             terminationLatch.countDown();
         }
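The `RateLimiterCreator` seam above constructs token-bucket limiters, and the endpoint handler later in this file caps accumulation at a single token, so an idle endpoint cannot bank a large burst. The following is an illustrative token-bucket limiter, not the Elasticsearch `RateLimiter` class, showing how a reserve call turns a token deficit into a wait time:

```java
import java.time.Duration;
import java.time.Instant;

// Illustrative token bucket: tokens accrue at a fixed rate up to a cap, and
// reserve() returns how long the caller should wait before proceeding.
class TokenBucketSketch {
    private final double capacity;        // max tokens that can accumulate
    private final double tokensPerSecond; // refill rate
    private double tokens;
    private Instant last;

    TokenBucketSketch(double capacity, double tokensPerSecond, Instant now) {
        this.capacity = capacity;
        this.tokensPerSecond = tokensPerSecond;
        this.tokens = capacity; // start full
        this.last = now;
    }

    /** Returns the time to wait before the reserved request may proceed. */
    Duration reserve(double requested, Instant now) {
        double elapsedSeconds = Duration.between(last, now).toNanos() / 1e9;
        // Refill, but never beyond the cap (a cap of 1 forbids bursts).
        tokens = Math.min(capacity, tokens + elapsedSeconds * tokensPerSecond);
        last = now;
        tokens -= requested;
        if (tokens >= 0) {
            return Duration.ZERO; // enough tokens banked: go immediately
        }
        double deficitSeconds = -tokens / tokensPerSecond;
        return Duration.ofNanos((long) (deficitSeconds * 1e9));
    }

    public static void main(String[] args) {
        Instant t0 = Instant.now();
        var bucket = new TokenBucketSketch(1, 2, t0); // 2 requests/second, burst of 1
        System.out.println(bucket.reserve(1, t0));    // PT0S - first request goes immediately
        System.out.println(bucket.reserve(1, t0));    // PT0.5S - next must wait half a second
    }
}
```

Passing the current `Instant` in (instead of reading a global clock) mirrors the injected `Clock` in the service: tests can advance time explicitly.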
@@ -156,108 +206,68 @@ class RequestExecutorService implements RequestExecutor {
         }
     }
 
-    /**
-     * Protects the task retrieval logic from an unexpected exception.
-     *
-     * @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to
-     * shut down
-     */
+    private void startCleanupTask() {
+        assert cancellableCleanupTask.get() == null : "The clean up task can only be set once";
+        cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL));
+    }
+
+    private Scheduler.Cancellable startCleanupThread(TimeValue interval) {
+        logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval));
+
+        return threadPool.scheduleWithFixedDelay(this::removeStaleGroupings, interval, threadPool.executor(UTILITY_THREAD_POOL_NAME));
+    }
+
+    // default for testing
+    void removeStaleGroupings() {
+        var now = Instant.now(clock);
+        for (var iter = rateLimitGroupings.values().iterator(); iter.hasNext();) {
+            var endpoint = iter.next();
+
+            // if the current time is after the last time the endpoint enqueued a request + allowed stale period then we'll remove it
+            if (now.isAfter(endpoint.timeOfLastEnqueue().plus(settings.getRateLimitGroupStaleDuration()))) {
+                endpoint.close();
+                iter.remove();
+            }
+        }
+    }
+
     private void handleTasks() throws InterruptedException {
-        try {
-            RejectableTask task = queue.take();
-
-            var command = controlQueue.poll();
-            if (command != null) {
-                command.run();
-            }
-
-            // TODO add logic to complete pending items in the queue before shutting down
-            if (running.get() == false) {
-                logger.debug(() -> format("Http executor service [%s] exiting", serviceName));
-                rejectTaskBecauseOfShutdown(task);
-            } else {
-                executeTask(task);
-            }
-        } catch (InterruptedException e) {
-            throw e;
-        } catch (Exception e) {
-            logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e);
+        var timeToWait = settings.getTaskPollFrequency();
+        for (var endpoint : rateLimitGroupings.values()) {
+            timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
         }
-    }
 
-    private void executeTask(RejectableTask task) {
-        try {
-            requestManager.execute(task, httpContext);
-        } catch (Exception e) {
-            logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e);
-        }
+        sleeper.sleep(timeToWait);
     }
 
-    private synchronized void notifyRequestsOfShutdown() {
+    private void notifyRequestsOfShutdown() {
         assert isShutdown() : "Requests should only be notified if the executor is shutting down";
 
-        try {
-            List<RejectableTask> notExecuted = new ArrayList<>();
-            queue.drainTo(notExecuted);
-
-            rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown);
-        } catch (Exception e) {
-            logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName));
+        for (var endpoint : rateLimitGroupings.values()) {
+            endpoint.notifyRequestsOfShutdown();
         }
     }
 
-    private void rejectTaskBecauseOfShutdown(RejectableTask task) {
-        try {
-            task.onRejection(
-                new EsRejectedExecutionException(
-                    format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName),
-                    true
-                )
-            );
-        } catch (Exception e) {
-            logger.warn(
-                format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName)
-            );
+    // default for testing
+    Integer remainingQueueCapacity(RequestManager requestManager) {
+        var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping());
+
+        if (endpoint == null) {
+            return null;
         }
+
+        return endpoint.remainingCapacity();
     }
 
-    private void rejectTasks(List<RejectableTask> tasks, Consumer<RejectableTask> rejectionFunction) {
-        for (var task : tasks) {
-            rejectionFunction.accept(task);
-        }
-    }
-
-    public int queueSize() {
-        return queue.size();
-    }
-
-    @Override
-    public void shutdown() {
-        if (running.compareAndSet(true, false)) {
-            // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns
-            queue.offer(new NoopTask());
-        }
-    }
-
-    @Override
-    public boolean isShutdown() {
-        return running.get() == false;
-    }
-
-    @Override
-    public boolean isTerminated() {
-        return terminationLatch.getCount() == 0;
-    }
-
-    @Override
-    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
-        return terminationLatch.await(timeout, unit);
+    // default for testing
+    int numberOfRateLimitGroups() {
+        return rateLimitGroupings.size();
    }
 
     /**
      * Execute the request at some point in the future.
      *
-     * @param requestCreator the http request to send
+     * @param requestManager the http request to send
      * @param inferenceInputs the inputs to send in the request
      * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the
      * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}.
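The reworked `handleTasks` above replaces a blocking `queue.take()` with a poll over every rate-limit group: each group executes at most one ready task and reports how long until its next one may run, and the loop then sleeps only for the smallest of those waits, bounded by the configured poll frequency. A condensed sketch of that scheduling decision, with illustrative types and numbers:

```java
import java.util.List;

// Sketch of a min-wait polling loop over independent rate-limited groups.
class PollLoopSketch {
    interface Endpoint {
        /** Runs one queued task if allowed; returns millis until the next task may run. */
        long executeEnqueuedTask();
    }

    static long nextWait(List<Endpoint> endpoints, long defaultPollMillis) {
        long wait = defaultPollMillis;
        for (Endpoint e : endpoints) {
            // Give every group a chance to run, then sleep only as long as
            // the soonest group (or the default poll interval) requires.
            wait = Math.min(e.executeEnqueuedTask(), wait);
        }
        return wait;
    }

    public static void main(String[] args) {
        // Three groups: one ready now, one throttled 250ms, one throttled 2s.
        List<Endpoint> endpoints = List.of(() -> 0L, () -> 250L, () -> 2000L);
        System.out.println(nextWait(endpoints, 50)); // 0 - a task just ran, loop again immediately
    }
}
```

Sleeping for the minimum wait (rather than a fixed interval) keeps a fast endpoint responsive without busy-spinning on behalf of a heavily throttled one.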
@@ -265,13 +275,13 @@ class RequestExecutorService implements RequestExecutor {
      * @param listener an {@link ActionListener<InferenceServiceResults>} for the response or failure
      */
     public void execute(
-        RequestManager requestCreator,
+        RequestManager requestManager,
         InferenceInputs inferenceInputs,
         @Nullable TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
         var task = new RequestTask(
-            requestCreator,
+            requestManager,
             inferenceInputs,
             timeout,
             threadPool,
@@ -280,38 +290,230 @@ class RequestExecutorService implements RequestExecutor {
             ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext())
         );
 
-        completeExecution(task);
-    }
-
-    private void completeExecution(RequestTask task) {
-        if (isShutdown()) {
-            EsRejectedExecutionException rejected = new EsRejectedExecutionException(
-                format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName),
-                true
+        var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> {
+            var endpointHandler = new RateLimitingEndpointHandler(
+                Integer.toString(requestManager.rateLimitGrouping().hashCode()),
+                queueCreator,
+                settings,
+                requestSender,
+                clock,
+                requestManager.rateLimitSettings(),
+                this::isShutdown,
+                rateLimiterCreator
             );
 
-            task.onRejection(rejected);
-            return;
-        }
+            endpointHandler.init();
+            return endpointHandler;
+        });
 
-        boolean added = queue.offer(task);
-        if (added == false) {
-            EsRejectedExecutionException rejected = new EsRejectedExecutionException(
-                format("Failed to execute task because the http executor service [%s] queue is full", serviceName),
-                false
-            );
-
-            task.onRejection(rejected);
-        } else if (isShutdown()) {
-            // It is possible that a shutdown and notification request occurred after we initially checked for shutdown above
-            // If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if
-            // we shut down and if so we'll redo the notification
-            notifyRequestsOfShutdown();
-        }
+        endpoint.enqueue(task);
     }
 
-    // default for testing
-    int remainingQueueCapacity() {
-        return queue.remainingCapacity();
-    }
+    /**
+     * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests.
+     * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent
+     * as soon as a thread is available.
+     */
+    private static class RateLimitingEndpointHandler {
+
+        private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE;
+        private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO;
+        private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class);
+        private static final int ACCUMULATED_TOKENS_LIMIT = 1;
+
+        private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
+        private final Supplier<Boolean> isShutdownMethod;
+        private final RequestSender requestSender;
+        private final String id;
+        private final AtomicReference<Instant> timeOfLastEnqueue = new AtomicReference<>();
+        private final Clock clock;
+        private final RateLimiter rateLimiter;
+        private final RequestExecutorServiceSettings requestExecutorServiceSettings;
+
+        RateLimitingEndpointHandler(
+            String id,
+            AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
+            RequestExecutorServiceSettings settings,
+            RequestSender requestSender,
+            Clock clock,
+            RateLimitSettings rateLimitSettings,
+            Supplier<Boolean> isShutdownMethod,
+            RateLimiterCreator rateLimiterCreator
+        ) {
+            this.requestExecutorServiceSettings = Objects.requireNonNull(settings);
+            this.id = Objects.requireNonNull(id);
+            this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
+            this.requestSender = Objects.requireNonNull(requestSender);
+            this.clock = Objects.requireNonNull(clock);
+            this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod);
|
Objects.requireNonNull(rateLimitSettings);
|
||||||
|
Objects.requireNonNull(rateLimiterCreator);
|
||||||
|
rateLimiter = rateLimiterCreator.create(
|
||||||
|
ACCUMULATED_TOKENS_LIMIT,
|
||||||
|
rateLimitSettings.requestsPerTimeUnit(),
|
||||||
|
rateLimitSettings.timeUnit()
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public void init() {
|
||||||
|
requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void onCapacityChange(int capacity) {
|
||||||
|
logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity));
|
||||||
|
|
||||||
|
try {
|
||||||
|
queue.setCapacity(capacity);
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int queueSize() {
|
||||||
|
return queue.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isShutdown() {
|
||||||
|
return isShutdownMethod.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Instant timeOfLastEnqueue() {
|
||||||
|
return timeOfLastEnqueue.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
public synchronized TimeValue executeEnqueuedTask() {
|
||||||
|
try {
|
||||||
|
return executeEnqueuedTaskInternal();
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.warn(format("Executor service grouping [%s] failed to execute request", id), e);
|
||||||
|
// we tried to do some work but failed, so we'll say we did something to try looking for more work
|
||||||
|
return EXECUTED_A_TASK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private TimeValue executeEnqueuedTaskInternal() {
|
||||||
|
var timeBeforeAvailableToken = rateLimiter.timeToReserve(1);
|
||||||
|
if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) {
|
||||||
|
return timeBeforeAvailableToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
var task = queue.poll();
|
||||||
|
|
||||||
|
// TODO Batching - in a situation where no new tasks are queued we'll want to execute any prepared tasks
|
||||||
|
// So we'll need to check for null and call a helper method executePreparedTasks()
|
||||||
|
|
||||||
|
if (shouldExecuteTask(task) == false) {
|
||||||
|
return NO_TASKS_AVAILABLE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should never have to wait because we checked above
|
||||||
|
var reserveRes = rateLimiter.reserve(1);
|
||||||
|
assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have";
|
||||||
|
|
||||||
|
task.getRequestManager()
|
||||||
|
.execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener());
|
||||||
|
return EXECUTED_A_TASK;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean shouldExecuteTask(RejectableTask task) {
|
||||||
|
return task != null && isNoopRequest(task) == false && task.hasCompleted() == false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
|
||||||
|
return inferenceRequest.getRequestManager() == null
|
||||||
|
|| inferenceRequest.getInput() == null
|
||||||
|
|| inferenceRequest.getListener() == null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean shouldExecuteImmediately(TimeValue delay) {
|
||||||
|
return delay.duration() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void enqueue(RequestTask task) {
|
||||||
|
timeOfLastEnqueue.set(Instant.now(clock));
|
||||||
|
|
||||||
|
if (isShutdown()) {
|
||||||
|
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
|
||||||
|
format(
|
||||||
|
"Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown",
|
||||||
|
task.getRequestManager().inferenceEntityId(),
|
||||||
|
id
|
||||||
|
),
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
task.onRejection(rejected);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
var addedToQueue = queue.offer(task);
|
||||||
|
|
||||||
|
if (addedToQueue == false) {
|
||||||
|
EsRejectedExecutionException rejected = new EsRejectedExecutionException(
|
||||||
|
format(
|
||||||
|
"Failed to execute task for inference id [%s] because the request service [%s] queue is full",
|
||||||
|
task.getRequestManager().inferenceEntityId(),
|
||||||
|
id
|
||||||
|
),
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
task.onRejection(rejected);
|
||||||
|
} else if (isShutdown()) {
|
||||||
|
notifyRequestsOfShutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public synchronized void notifyRequestsOfShutdown() {
|
||||||
|
assert isShutdown() : "Requests should only be notified if the executor is shutting down";
|
||||||
|
|
||||||
|
try {
|
||||||
|
List<RejectableTask> notExecuted = new ArrayList<>();
|
||||||
|
queue.drainTo(notExecuted);
|
||||||
|
|
||||||
|
rejectTasks(notExecuted);
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rejectTasks(List<RejectableTask> tasks) {
|
||||||
|
for (var task : tasks) {
|
||||||
|
rejectTaskForShutdown(task);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rejectTaskForShutdown(RejectableTask task) {
|
||||||
|
try {
|
||||||
|
task.onRejection(
|
||||||
|
new EsRejectedExecutionException(
|
||||||
|
format(
|
||||||
|
"Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request",
|
||||||
|
id,
|
||||||
|
task.getRequestManager().inferenceEntityId()
|
||||||
|
),
|
||||||
|
true
|
||||||
|
)
|
||||||
|
);
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.warn(
|
||||||
|
format(
|
||||||
|
"Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown",
|
||||||
|
task.getRequestManager().inferenceEntityId(),
|
||||||
|
id
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int remainingCapacity() {
|
||||||
|
return queue.remainingCapacity();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void close() {
|
||||||
|
requestExecutorServiceSettings.deregisterQueueCapacityCallback(id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
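The `RateLimitingEndpointHandler` above leans on a token-bucket style `RateLimiter`: `timeToReserve(1)` reports how long a caller would have to wait for a token without consuming anything, and `reserve(1)` actually takes the token, so `executeEnqueuedTaskInternal` can peek first and only poll the queue when a token is available. A minimal, self-contained sketch of that contract (the class and method names here are illustrative, not the actual Elasticsearch `RateLimiter`; time is passed in explicitly to keep it deterministic):

```java
import java.time.Duration;

// Hypothetical token-bucket sketch: tokens accumulate at a fixed rate up to a cap,
// timeToReserveMillis() peeks at the required wait, reserveMillis() consumes tokens.
final class TokenBucketSketch {
    private final double capacity;          // max tokens that can accumulate (ACCUMULATED_TOKENS_LIMIT in the diff)
    private final double tokensPerMillis;   // refill rate derived from requests-per-time-unit
    private double tokens;
    private long lastRefillMillis;

    TokenBucketSketch(double capacity, double requestsPerTimeUnit, Duration timeUnit, long nowMillis) {
        this.capacity = capacity;
        this.tokensPerMillis = requestsPerTimeUnit / (double) timeUnit.toMillis();
        this.tokens = capacity;
        this.lastRefillMillis = nowMillis;
    }

    private void refill(long nowMillis) {
        tokens = Math.min(capacity, tokens + (nowMillis - lastRefillMillis) * tokensPerMillis);
        lastRefillMillis = nowMillis;
    }

    /** Milliseconds until `amount` tokens are available; 0 means execute immediately. */
    long timeToReserveMillis(int amount, long nowMillis) {
        refill(nowMillis);
        if (tokens >= amount) {
            return 0;
        }
        return (long) Math.ceil((amount - tokens) / tokensPerMillis);
    }

    /** Consumes the tokens and returns the wait that would have been required (0 when none). */
    long reserveMillis(int amount, long nowMillis) {
        long wait = timeToReserveMillis(amount, nowMillis);
        tokens -= amount;
        return wait;
    }
}
```

With a cap of 1 token (as in the handler), at most one request's worth of credit can accumulate while the group is idle, so a burst after a quiet period is still spaced out at the configured rate.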
@@ -10,9 +10,12 @@ package org.elasticsearch.xpack.inference.external.http.sender;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
 
-import java.util.ArrayList;
+import java.time.Duration;
 import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.function.Consumer;
 
 public class RequestExecutorServiceSettings {
@@ -29,37 +32,108 @@ public class RequestExecutorServiceSettings {
         Setting.Property.Dynamic
     );
 
+    private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50);
+    /**
+     * Defines how often all the rate limit groups are polled for tasks. Setting this to very low number could result
+     * in a busy loop if there are no tasks available to handle.
+     */
+    static final Setting<TimeValue> TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting(
+        "xpack.inference.http.request_executor.task_poll_frequency",
+        DEFAULT_TASK_POLL_FREQUENCY_TIME,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
+    /**
+     * Defines how often a thread will check for rate limit groups that are stale.
+     */
+    static final Setting<TimeValue> RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING = Setting.timeSetting(
+        "xpack.inference.http.request_executor.rate_limit_group_cleanup_interval",
+        DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION = TimeValue.timeValueDays(10);
+    /**
+     * Defines the amount of time it takes to classify a rate limit group as stale. Once it is classified as stale,
+     * it can be removed when the cleanup thread executes.
+     */
+    static final Setting<TimeValue> RATE_LIMIT_GROUP_STALE_DURATION_SETTING = Setting.timeSetting(
+        "xpack.inference.http.request_executor.rate_limit_group_stale_duration",
+        DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
     public static List<Setting<?>> getSettingsDefinitions() {
-        return List.of(TASK_QUEUE_CAPACITY_SETTING);
+        return List.of(
+            TASK_QUEUE_CAPACITY_SETTING,
+            TASK_POLL_FREQUENCY_SETTING,
+            RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING,
+            RATE_LIMIT_GROUP_STALE_DURATION_SETTING
+        );
     }
 
     private volatile int queueCapacity;
-    private final List<Consumer<Integer>> queueCapacityCallbacks = new ArrayList<Consumer<Integer>>();
+    private volatile TimeValue taskPollFrequency;
+    private volatile Duration rateLimitGroupStaleDuration;
+    private final ConcurrentMap<String, Consumer<Integer>> queueCapacityCallbacks = new ConcurrentHashMap<>();
 
     public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) {
         queueCapacity = TASK_QUEUE_CAPACITY_SETTING.get(settings);
+        taskPollFrequency = TASK_POLL_FREQUENCY_SETTING.get(settings);
+        setRateLimitGroupStaleDuration(RATE_LIMIT_GROUP_STALE_DURATION_SETTING.get(settings));
+
         addSettingsUpdateConsumers(clusterService);
     }
 
     private void addSettingsUpdateConsumers(ClusterService clusterService) {
         clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency);
+        clusterService.getClusterSettings()
+            .addSettingsUpdateConsumer(RATE_LIMIT_GROUP_STALE_DURATION_SETTING, this::setRateLimitGroupStaleDuration);
     }
 
     // default for testing
     void setQueueCapacity(int queueCapacity) {
         this.queueCapacity = queueCapacity;
 
-        for (var callback : queueCapacityCallbacks) {
+        for (var callback : queueCapacityCallbacks.values()) {
             callback.accept(queueCapacity);
         }
     }
 
-    void registerQueueCapacityCallback(Consumer<Integer> onChangeCapacityCallback) {
-        queueCapacityCallbacks.add(onChangeCapacityCallback);
+    private void setTaskPollFrequency(TimeValue taskPollFrequency) {
+        this.taskPollFrequency = taskPollFrequency;
+    }
+
+    private void setRateLimitGroupStaleDuration(TimeValue staleDuration) {
+        rateLimitGroupStaleDuration = toDuration(staleDuration);
+    }
+
+    private static Duration toDuration(TimeValue timeValue) {
+        return Duration.of(timeValue.duration(), timeValue.timeUnit().toChronoUnit());
+    }
+
+    void registerQueueCapacityCallback(String id, Consumer<Integer> onChangeCapacityCallback) {
+        queueCapacityCallbacks.put(id, onChangeCapacityCallback);
+    }
+
+    void deregisterQueueCapacityCallback(String id) {
+        queueCapacityCallbacks.remove(id);
     }
 
     int getQueueCapacity() {
         return queueCapacity;
     }
 
+    TimeValue getTaskPollFrequency() {
+        return taskPollFrequency;
+    }
+
+    Duration getRateLimitGroupStaleDuration() {
+        return rateLimitGroupStaleDuration;
+    }
 }
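The settings change above replaces a plain callback list with a `ConcurrentMap` keyed by endpoint id, so each rate limit group can deregister its own callback in `close()` without touching anyone else's. A stripped-down sketch of that keyed-callback pattern (class and method names here are illustrative, not the real Elasticsearch class):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

// Hypothetical reduction of the keyed-callback pattern: each rate limit group
// registers under its own id and can deregister independently on close().
final class CapacityNotifier {
    private final Map<String, Consumer<Integer>> callbacks = new ConcurrentHashMap<>();

    void register(String id, Consumer<Integer> onChange) {
        callbacks.put(id, onChange);
    }

    void deregister(String id) {
        callbacks.remove(id);
    }

    void capacityChanged(int newCapacity) {
        // ConcurrentHashMap's values() view tolerates concurrent registration/removal,
        // so no explicit locking is needed while fanning out the new capacity
        for (Consumer<Integer> callback : callbacks.values()) {
            callback.accept(newCapacity);
        }
    }
}
```

The per-id map is what makes the cleanup task workable: a stale group can be removed and its callback dropped without rebuilding a shared list.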
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -21,14 +20,17 @@ import java.util.function.Supplier;
  * A contract for constructing a {@link Runnable} to handle sending an inference request to a 3rd party service.
  */
 public interface RequestManager extends RateLimitable {
-    Runnable create(
+    void execute(
         @Nullable String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     );
 
+    // TODO For batching we'll add 2 new method: prepare(query, input, ...) which will allow the individual
+    // managers to implement their own batching
+    // executePreparedRequest() which will execute all prepared requests aka sends the batch
+
     String inferenceEntityId();
 }
@@ -111,7 +111,7 @@ class RequestTask implements RejectableTask {
     }
 
     @Override
-    public RequestManager getRequestCreator() {
+    public RequestManager getRequestManager() {
         return requestCreator;
     }
 }
@@ -1,48 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.sender;
-
-import org.apache.http.client.protocol.HttpClientContext;
-import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
-
-import java.util.Objects;
-
-/**
- * Handles executing a single inference request at a time.
- */
-public class SingleRequestManager {
-
-    protected RetryingHttpSender requestSender;
-
-    public SingleRequestManager(RetryingHttpSender requestSender) {
-        this.requestSender = Objects.requireNonNull(requestSender);
-    }
-
-    public void execute(InferenceRequest inferenceRequest, HttpClientContext context) {
-        if (isNoopRequest(inferenceRequest) || inferenceRequest.hasCompleted()) {
-            return;
-        }
-
-        inferenceRequest.getRequestCreator()
-            .create(
-                inferenceRequest.getQuery(),
-                inferenceRequest.getInput(),
-                requestSender,
-                inferenceRequest.getRequestCompletedFunction(),
-                context,
-                inferenceRequest.getListener()
-            )
-            .run();
-    }
-
-    private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
-        return inferenceRequest.getRequestCreator() == null
-            || inferenceRequest.getInput() == null
-            || inferenceRequest.getListener() == null;
-    }
-}
@@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService {
 
     public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         Objects.requireNonNull(factory);
-        sender = factory.createSender(name());
+        sender = factory.createSender();
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
     }
 
@@ -56,7 +56,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completio
 
 public class AzureAiStudioService extends SenderService {
 
-    private static final String NAME = "azureaistudio";
+    static final String NAME = "azureaistudio";
 
     public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
@@ -44,7 +44,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
         ConfigurationParseContext context
     ) {
         String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureAiStudioService.NAME,
+            context
+        );
         AzureAiStudioEndpointType endpointType = extractRequiredEnum(
             map,
             ENDPOINT_TYPE_FIELD,
@@ -118,13 +124,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
 
     protected void addXContentFields(XContentBuilder builder, Params params) throws IOException {
         this.addExposedXContentFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
     }
 
     protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(TARGET_FIELD, this.target);
         builder.field(PROVIDER_FIELD, this.provider);
         builder.field(ENDPOINT_TYPE_FIELD, this.endpointType);
+        rateLimitSettings.toXContent(builder, params);
     }
 
 }
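Across the service settings classes, the updated `RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, <service>.NAME, context)` calls read an optional per-endpoint rate limit override out of the request's service settings map and fall back to the service default. A simplified sketch of that lookup, with the validation and parse-context plumbing omitted (the `rate_limit.requests_per_minute` field names mirror the convention in this PR, but the helper class itself is hypothetical):

```java
import java.util.Map;

// Hypothetical reduction of RateLimitSettings.of-style parsing: consume the
// optional "rate_limit" object from the mutable service settings map and read
// "requests_per_minute" from it, defaulting when absent. Real code also records
// validation errors and may emit deprecation warnings based on the parse context.
final class RateLimitParser {
    static long parseRequestsPerMinute(Map<String, Object> serviceSettings, long defaultRpm) {
        Object rateLimit = serviceSettings.remove("rate_limit");
        if (rateLimit instanceof Map<?, ?> rateLimitMap) {
            Object rpm = rateLimitMap.get("requests_per_minute");
            if (rpm instanceof Number number) {
                return number.longValue();
            }
        }
        return defaultRpm;
    }
}
```

Removing the `rate_limit` entry from the map matters in the real code: leftover keys in the settings map are treated as unknown fields and rejected.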
@@ -135,7 +135,15 @@ public class AzureOpenAiService extends SenderService {
             );
         }
         case COMPLETION -> {
-            return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+            return new AzureOpenAiCompletionModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
         }
         default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
     }
@@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
 
@@ -37,13 +38,14 @@ public class AzureOpenAiCompletionModel extends AzureOpenAiModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings),
+            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context),
             AzureOpenAiCompletionTaskSettings.fromMap(taskSettings),
             AzureOpenAiSecretSettings.fromMap(secrets)
         );
@@ -17,7 +17,9 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
@@ -55,10 +57,10 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
 
-    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
 
-        var settings = fromMap(map, validationException);
+        var settings = fromMap(map, validationException, context);
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -69,12 +71,19 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
 
     private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap(
         Map<String, Object> map,
-        ValidationException validationException
+        ValidationException validationException,
+        ConfigurationParseContext context
     ) {
         String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureOpenAiService.NAME,
+            context
+        );
 
         return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings);
     }
@@ -137,7 +146,6 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
         builder.startObject();
 
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
 
         builder.endObject();
         return builder;
@@ -148,6 +156,7 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
         builder.field(RESOURCE_NAME, resourceName);
         builder.field(DEPLOYMENT_ID, deploymentId);
         builder.field(API_VERSION, apiVersion);
+        rateLimitSettings.toXContent(builder, params);
 
         return builder;
     }
@@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -90,7 +91,13 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
         SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureOpenAiService.NAME,
+            context
+        );

         Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);

@@ -245,8 +252,6 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-
-        rateLimitSettings.toXContent(builder, params);
         builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);

         builder.endObject();

@@ -268,6 +273,7 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         if (similarity != null) {
             builder.field(SIMILARITY, similarity);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

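The recurring change in the hunks above replaces the three-argument `RateLimitSettings.of(map, defaults, validationException)` with a five-argument form that also receives the service name and the `ConfigurationParseContext`. A minimal sketch of what such a context-aware parser can look like is below; the field names, the error-reporting behavior, and all class names here are assumptions for illustration only, not the actual Elasticsearch implementation.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of a context-aware rate-limit settings parser.
// REQUEST vs PERSISTENT mirrors "came from an API call" vs "read back from storage".
class RateLimitSketch {
    enum ParseContext { REQUEST, PERSISTENT }

    final long requestsPerMinute;

    RateLimitSketch(long requestsPerMinute) {
        this.requestsPerMinute = requestsPerMinute;
    }

    static RateLimitSketch of(
        Map<String, Object> map,
        RateLimitSketch defaults,
        StringBuilder validationErrors,
        String serviceName,
        ParseContext context
    ) {
        // Consume the field so leftover-key validation elsewhere does not trip on it.
        Object rateLimit = map.remove("rate_limit");
        if (rateLimit == null) {
            return defaults; // nothing specified: fall back to the service default
        }
        if (rateLimit instanceof Map<?, ?> settings) {
            Object rpm = settings.get("requests_per_minute");
            if (rpm instanceof Number n && n.longValue() > 0) {
                return new RateLimitSketch(n.longValue());
            }
        }
        // Assumption for illustration: invalid values are reported with the service
        // name and context so the error points at the offending endpoint config.
        validationErrors.append('[').append(serviceName).append("] invalid [rate_limit] (").append(context).append(");");
        return defaults;
    }
}
```

Passing the service name down means a shared parser can still produce service-specific validation messages, which is one plausible reason for widening the signature.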
@@ -51,6 +51,11 @@ import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFie
 public class CohereService extends SenderService {
     public static final String NAME = "cohere";

+    // TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to
+    // the Cohere*RequestManager via the CohereActionCreator class
+    // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
+    // on every request
+
     public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
     }

@@ -131,7 +136,15 @@ public class CohereService extends SenderService {
                 context
             );
             case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
-            case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+            case COMPLETION -> new CohereCompletionModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }

@@ -58,7 +58,13 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            CohereService.NAME,
+            context
+        );

         String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

@@ -173,10 +179,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
     }

     public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
-        toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
-
-        return builder;
+        return toXContentFragmentOfExposedFields(builder, params);
     }

     @Override

@@ -196,6 +199,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
         if (modelId != null) {
             builder.field(MODEL_ID, modelId);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

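Across the settings classes above, `rateLimitSettings.toXContent(...)` moves out of the outer `toXContent`/`toXContentFragment` methods and into `toXContentFragmentOfExposedFields`, so the rate limit is serialized in one place and also appears in the filtered "exposed fields" view. The shape of that pattern, reduced to a toy string builder (the class and method names are illustrative stand-ins, not the real XContent API):

```java
// Toy version of the fragment pattern: toXContent() only wraps an object around
// the exposed-fields fragment, so anything written in the fragment shows up in
// both the full and the filtered serialization exactly once.
class SettingsSketch {
    final String modelId;
    final int requestsPerMinute;

    SettingsSketch(String modelId, int requestsPerMinute) {
        this.modelId = modelId;
        this.requestsPerMinute = requestsPerMinute;
    }

    StringBuilder toXContent(StringBuilder b) {
        b.append("{");
        toXContentFragmentOfExposedFields(b); // rate limit now lives in the fragment
        b.append("}");
        return b;
    }

    StringBuilder toXContentFragmentOfExposedFields(StringBuilder b) {
        b.append("\"model_id\":\"").append(modelId).append("\",");
        b.append("\"rate_limit\":{\"requests_per_minute\":").append(requestsPerMinute).append("}");
        return b;
    }
}
```

The design choice here is deduplication: before the change, the outer method wrote the rate limit itself, so a class with both a full and a filtered serialization path had to remember to write it twice.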
@@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -30,13 +31,14 @@ public class CohereCompletionModel extends CohereModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             modelId,
             taskType,
             service,
-            CohereCompletionServiceSettings.fromMap(serviceSettings),
+            CohereCompletionServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );

@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.cohere.CohereService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -39,12 +41,18 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
     // 10K requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);

-    public static CohereCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static CohereCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            CohereService.NAME,
+            context
+        );
         String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

         if (validationException.validationErrors().isEmpty() == false) {

@@ -94,7 +102,6 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;

@@ -127,6 +134,7 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
         if (modelId != null) {
             builder.field(MODEL_ID, modelId);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

@@ -108,7 +108,8 @@ public class GoogleAiStudioService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel(
                 inferenceEntityId,

@@ -116,7 +117,8 @@ public class GoogleAiStudioService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };

@@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -37,13 +38,14 @@ public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        Map<String, Object> secrets
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings),
+            GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );

@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -40,11 +42,17 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
     */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);

-    public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            GoogleAiStudioService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;

@@ -82,7 +90,6 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;

@@ -107,6 +114,7 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
     @Override
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(MODEL_ID, modelId);
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

@@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -37,13 +38,14 @@ public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        Map<String, Object> secrets
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings),
+            GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );

@@ -18,7 +18,9 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -47,7 +49,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
     */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);

-    public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map) {
+    public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

@@ -59,7 +61,13 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         );
         SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
         Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            GoogleAiStudioService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;

@@ -134,7 +142,6 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;

@@ -174,6 +181,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         if (similarity != null) {
             builder.field(SIMILARITY, similarity);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;

@@ -62,7 +63,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
             taskType,
             serviceSettingsMap,
             serviceSettingsMap,
-            TaskType.unsupportedTaskTypeErrorMsg(taskType, name())
+            TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
+            ConfigurationParseContext.REQUEST
         );

         throwIfNotEmptyMap(config, name());

@@ -89,7 +91,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
             taskType,
             serviceSettingsMap,
             secretSettingsMap,
-            parsePersistedConfigErrorMsg(inferenceEntityId, name())
+            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
+            ConfigurationParseContext.PERSISTENT
         );
     }

@@ -97,7 +100,14 @@ public abstract class HuggingFaceBaseService extends SenderService {
     public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);

-        return createModel(inferenceEntityId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(inferenceEntityId, name()));
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            null,
+            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
+            ConfigurationParseContext.PERSISTENT
+        );
     }

     protected abstract HuggingFaceModel createModel(

@@ -105,7 +115,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     );

     @Override

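The `HuggingFaceBaseService` hunks above show the same `createModel` hook being reached from two entry points, with `ConfigurationParseContext.REQUEST` passed when handling a live create request and `ConfigurationParseContext.PERSISTENT` when re-reading a stored configuration. A reduced sketch of that dispatch pattern follows; the class and method names are illustrative stand-ins under that assumption, not the Elasticsearch types.

```java
import java.util.Map;

// Sketch of threading a parse context from the two parse entry points down to a
// single model-creation hook, so settings parsers can distinguish a live request
// (strict validation) from a config re-read from storage (lenient by assumption).
abstract class BaseServiceSketch {
    enum ParseContext { REQUEST, PERSISTENT }

    // Called for a new create-endpoint request.
    ModelSketch parseRequestConfig(String id, Map<String, Object> settings) {
        return createModel(id, settings, ParseContext.REQUEST);
    }

    // Called when loading an endpoint configuration that was previously stored.
    ModelSketch parsePersistedConfig(String id, Map<String, Object> settings) {
        return createModel(id, settings, ParseContext.PERSISTENT);
    }

    protected abstract ModelSketch createModel(String id, Map<String, Object> settings, ParseContext context);
}

class ModelSketch {
    final String id;
    final BaseServiceSketch.ParseContext context;

    ModelSketch(String id, BaseServiceSketch.ParseContext context) {
        this.id = id;
        this.context = context;
    }
}
```

Keeping the context as an explicit parameter, rather than a field on the service, keeps the abstract hook stateless, which matches the diff's note that the request managers are instantiated per request.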
@@ -16,6 +16,7 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;

@@ -36,11 +37,19 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return switch (taskType) {
-            case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
-            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
+            case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                secretSettings,
+                context
+            );
+            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }

@@ -18,6 +18,7 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -43,14 +44,20 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
     // 3000 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);

-    public static HuggingFaceServiceSettings fromMap(Map<String, Object> map) {
+    public static HuggingFaceServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var uri = extractUri(map, URL, validationException);

         SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            HuggingFaceService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -119,7 +126,6 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
     }
@@ -136,6 +142,7 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
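Several hunks in this file (and the analogous ones below) move `rateLimitSettings.toXContent(...)` out of `toXContent` and into `toXContentFragmentOfExposedFields`. A toy sketch of why that placement matters, assuming the `FilteredXContentObject` base class reuses the exposed-fields fragment for the filtered, user-facing serialization — all names below are illustrative, not the real Elasticsearch API:

```java
// Hypothetical miniature of the FilteredXContentObject pattern: a shared
// "exposed fields" fragment feeds both the full persisted form and the
// filtered view. Writing rate_limit inside the fragment keeps it visible
// in both, instead of only in the full form.
abstract class FilteredObjectSketch {
    abstract StringBuilder exposedFragment(StringBuilder b);

    // Full (persisted) form: exposed fragment plus internal-only fields.
    final String full() {
        return exposedFragment(new StringBuilder("{")).append(",\"dimensions_set_by_user\":false}").toString();
    }

    // Filtered (user-facing) form: exposed fragment only.
    final String filtered() {
        return exposedFragment(new StringBuilder("{")).append("}").toString();
    }
}

class ServiceSettingsSketch extends FilteredObjectSketch {
    @Override
    StringBuilder exposedFragment(StringBuilder b) {
        return b.append("\"url\":\"https://example.test\",\"rate_limit\":{\"requests_per_minute\":3000}");
    }

    public static void main(String[] args) {
        var settings = new ServiceSettingsSketch();
        System.out.println(settings.filtered().contains("rate_limit")); // prints true
    }
}
```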
@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -24,13 +25,14 @@ public class HuggingFaceElserModel extends HuggingFaceModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            HuggingFaceElserServiceSettings.fromMap(serviceSettings),
+            HuggingFaceElserServiceSettings.fromMap(serviceSettings, context),
             DefaultSecretSettings.fromMap(secrets)
         );
     }
@@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
@@ -38,10 +39,11 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return switch (taskType) {
-            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
+            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }
@@ -15,7 +15,9 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -40,10 +42,16 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     // 3000 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);

-    public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map) {
+    public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var uri = extractUri(map, URL, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            HuggingFaceService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -93,7 +101,6 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();

         return builder;
@@ -103,6 +110,7 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(URL, uri.toString());
         builder.field(MAX_INPUT_TOKENS, ELSER_TOKEN_LIMIT);
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -25,13 +26,14 @@ public class HuggingFaceEmbeddingsModel extends HuggingFaceModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            HuggingFaceServiceSettings.fromMap(serviceSettings),
+            HuggingFaceServiceSettings.fromMap(serviceSettings, context),
             DefaultSecretSettings.fromMap(secrets)
         );
     }
@@ -18,6 +18,7 @@ import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -59,7 +60,13 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
             ModelConfigurations.SERVICE_SETTINGS,
             validationException
         );
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            MistralService.NAME,
+            context
+        );
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);

         if (validationException.validationErrors().isEmpty() == false) {
@@ -141,7 +148,6 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         this.toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
     }
@@ -159,6 +165,7 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
         if (this.maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, this.maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -138,7 +138,8 @@ public class OpenAiService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -35,13 +36,14 @@ public class OpenAiChatCompletionModel extends OpenAiModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            OpenAiChatCompletionServiceSettings.fromMap(serviceSettings),
+            OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
             OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
             DefaultSecretSettings.fromMap(secrets)
         );
@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -47,7 +49,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
     // 500 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);

-    public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -58,7 +60,13 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject

         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);

-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            OpenAiService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -142,7 +150,6 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -163,6 +170,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -20,6 +20,7 @@ import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -66,7 +67,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         // passed at that time and never throw.
         ValidationException validationException = new ValidationException();

-        var commonFields = fromMap(map, validationException);
+        var commonFields = fromMap(map, validationException, ConfigurationParseContext.PERSISTENT);

         Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
         if (dimensionsSetByUser == null) {
@@ -80,7 +81,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
     private static OpenAiEmbeddingsServiceSettings fromRequestMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();

-        var commonFields = fromMap(map, validationException);
+        var commonFields = fromMap(map, validationException, ConfigurationParseContext.REQUEST);

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -89,7 +90,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         return new OpenAiEmbeddingsServiceSettings(commonFields, commonFields.dimensions != null);
     }

-    private static CommonFields fromMap(Map<String, Object> map, ValidationException validationException) {
+    private static CommonFields fromMap(
+        Map<String, Object> map,
+        ValidationException validationException,
+        ConfigurationParseContext context
+    ) {

         String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -98,7 +103,13 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            OpenAiService.NAME,
+            context
+        );

         return new CommonFields(modelId, uri, organizationId, similarity, maxInputTokens, dims, rateLimitSettings);
     }
@@ -258,7 +269,6 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         if (dimensionsSetByUser != null) {
             builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
@@ -286,6 +296,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;

 import java.io.IOException;
 import java.util.Map;
@@ -21,19 +22,29 @@ import java.util.concurrent.TimeUnit;

 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;

 public class RateLimitSettings implements Writeable, ToXContentFragment {

     public static final String FIELD_NAME = "rate_limit";
     public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute";

     private final long requestsPerTimeUnit;
     private final TimeUnit timeUnit;

-    public static RateLimitSettings of(Map<String, Object> map, RateLimitSettings defaultValue, ValidationException validationException) {
+    public static RateLimitSettings of(
+        Map<String, Object> map,
+        RateLimitSettings defaultValue,
+        ValidationException validationException,
+        String serviceName,
+        ConfigurationParseContext context
+    ) {
         Map<String, Object> settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME);
         var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException);

+        if (ConfigurationParseContext.isRequestContext(context)) {
+            throwIfNotEmptyMap(settings, serviceName);
+        }
+
         return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute);
     }
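A rough, self-contained sketch of what the expanded `RateLimitSettings.of(...)` above does, with Elasticsearch's `ValidationException` and `ServiceUtils` helpers replaced by plain-Java stand-ins (all names below are hypothetical):

```java
import java.util.HashMap;
import java.util.Map;

class RateLimitParseSketch {
    enum ParseContext { REQUEST, PERSISTENT }

    // Simplified stand-in for RateLimitSettings.of(...): read
    // rate_limit.requests_per_minute, fall back to the service default, and
    // reject unknown fields only for request-supplied settings.
    @SuppressWarnings("unchecked")
    static long of(Map<String, Object> map, long defaultRpm, ParseContext context) {
        Map<String, Object> settings = (Map<String, Object>) map.getOrDefault("rate_limit", new HashMap<>());
        Object rpm = settings.remove("requests_per_minute");

        // Mirrors the throwIfNotEmptyMap call: persisted configs are parsed
        // leniently so older stored settings keep working; request configs
        // fail loudly on leftover keys.
        if (context == ParseContext.REQUEST && settings.isEmpty() == false) {
            throw new IllegalArgumentException("unknown rate_limit fields: " + settings.keySet());
        }
        return rpm == null ? defaultRpm : ((Number) rpm).longValue();
    }

    public static void main(String[] args) {
        Map<String, Object> map = new HashMap<>();
        map.put("rate_limit", new HashMap<>(Map.of("requests_per_minute", 120)));
        System.out.println(of(map, 3000, ParseContext.REQUEST)); // prints 120
        System.out.println(of(new HashMap<>(), 3000, ParseContext.REQUEST)); // prints 3000
    }
}
```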
@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -92,7 +93,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             TruncatorTests.createTruncator()
         );

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();

             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
@@ -141,7 +142,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             TruncatorTests.createTruncator()
         );

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();

             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));
@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
@@ -82,7 +83,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -132,7 +133,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -183,7 +184,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
         // timeout as zero for no retries
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -237,7 +238,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // note - there is no complete documentation on Azure's error messages
@@ -313,7 +314,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // note - there is no complete documentation on Azure's error messages
@@ -389,7 +390,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testExecute_TruncatesInputBeforeSending() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -440,7 +441,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -498,7 +499,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -554,7 +555,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
         // timeout as zero for no retries
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // "choices" missing
@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
 import static org.hamcrest.Matchers.hasSize;
@@ -77,7 +78,7 @@ public class AzureOpenAiCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
@@ -81,7 +82,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
             mockClusterServiceEmpty()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -73,7 +74,7 @@ public class CohereActionCreatorTests extends ESTestCase {
     public void testCreate_CohereEmbeddingsModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -154,7 +155,7 @@ public class CohereActionCreatorTests extends ESTestCase {
     public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -214,7 +215,7 @@ public class CohereActionCreatorTests extends ESTestCase {
     public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -77,7 +77,7 @@ public class CohereCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -138,7 +138,7 @@ public class CohereCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -290,7 +290,7 @@ public class CohereCompletionActionTests extends ESTestCase {
     public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
            sender.start();
 
             String responseJson = """
@@ -81,7 +81,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -162,7 +162,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -74,7 +74,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -206,7 +206,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
     public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -79,7 +79,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
         var input = "input";
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             sender.start();
 
             String responseJson = """
@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.Matchers.contains;
@@ -75,7 +76,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -131,7 +132,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -187,7 +188,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -239,7 +240,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]]
@@ -292,7 +293,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJsonContentTooLarge = """
@@ -357,7 +358,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_TruncatesInputBeforeSending() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -74,7 +75,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiEmbeddingsModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -127,7 +128,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -179,7 +180,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -238,7 +239,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -292,7 +293,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiChatCompletionModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -355,7 +356,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var sender = senderFactory.createSender("test_service")) {
|
try (var sender = createSender(senderFactory)) {
|
||||||
sender.start();
|
sender.start();
|
||||||
|
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
|
@ -417,7 +418,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
||||||
public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException {
|
public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var sender = senderFactory.createSender("test_service")) {
|
try (var sender = createSender(senderFactory)) {
|
||||||
sender.start();
|
sender.start();
|
||||||
|
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
|
@ -486,7 +487,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
||||||
);
|
);
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||||
|
|
||||||
try (var sender = senderFactory.createSender("test_service")) {
|
try (var sender = createSender(senderFactory)) {
|
||||||
sender.start();
|
sender.start();
|
||||||
|
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
|
@ -552,7 +553,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
||||||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
|
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var sender = senderFactory.createSender("test_service")) {
|
try (var sender = createSender(senderFactory)) {
|
||||||
sender.start();
|
sender.start();
|
||||||
|
|
||||||
var contentTooLargeErrorMessage =
|
var contentTooLargeErrorMessage =
|
||||||
|
@ -635,7 +636,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
||||||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
|
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var sender = senderFactory.createSender("test_service")) {
|
try (var sender = createSender(senderFactory)) {
|
||||||
sender.start();
|
sender.start();
|
||||||
|
|
||||||
var contentTooLargeErrorMessage =
|
var contentTooLargeErrorMessage =
|
||||||
|
@ -718,7 +719,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
||||||
public void testExecute_TruncatesInputBeforeSending() throws IOException {
|
public void testExecute_TruncatesInputBeforeSending() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var sender = senderFactory.createSender("test_service")) {
|
try (var sender = createSender(senderFactory)) {
|
||||||
sender.start();
|
sender.start();
|
||||||
|
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
|
|
|
@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -80,7 +81,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -234,7 +235,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
     public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -79,7 +79,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
             mockClusterServiceEmpty()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             sender.start();
 
             String responseJson = """
@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.mockito.Mockito.mock;
+
+public class BaseRequestManagerTests extends ESTestCase {
+    public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() {
+        int val1 = 1;
+        int val2 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), is(manager2.rateLimitGrouping()));
+    }
+
+    public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() {
+        int val1 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
+    }
+
+    public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() {
+        int val1 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
+    }
+}
@@ -79,7 +79,7 @@ public class HttpRequestSenderTests extends ESTestCase {
     public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
         var senderFactory = createSenderFactory(clientManager, threadRef);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -135,11 +135,11 @@ public class HttpRequestSenderTests extends ESTestCase {
             mockClusterServiceEmpty()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var thrownException = expectThrows(
                 AssertionError.class,
-                () -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
+                () -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
             );
             assertThat(thrownException.getMessage(), is("call start() before sending a request"));
         }
@@ -155,17 +155,12 @@ public class HttpRequestSenderTests extends ESTestCase {
             mockClusterServiceEmpty()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             assertThat(sender, instanceOf(HttpRequestSender.class));
             sender.start();
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            sender.send(
-                ExecutableRequestCreatorTests.createMock(),
-                new DocumentsOnlyInput(List.of()),
-                TimeValue.timeValueNanos(1),
-                listener
-            );
+            sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
 
             var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -186,16 +181,11 @@ public class HttpRequestSenderTests extends ESTestCase {
             mockClusterServiceEmpty()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             sender.start();
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            sender.send(
-                ExecutableRequestCreatorTests.createMock(),
-                new DocumentsOnlyInput(List.of()),
-                TimeValue.timeValueNanos(1),
-                listener
-            );
+            sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
 
             var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -220,6 +210,7 @@ public class HttpRequestSenderTests extends ESTestCase {
         when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService);
         when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
         when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class));
+        when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
 
         return new HttpRequestSender.Factory(
             ServiceComponentsTests.createWithEmptySettings(mockThreadPool),
@@ -248,7 +239,7 @@ public class HttpRequestSenderTests extends ESTestCase {
         );
     }
 
-    public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) {
-        return factory.createSender(serviceName);
+    public static Sender createSender(HttpRequestSender.Factory factory) {
+        return factory.createSender();
     }
 }
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.external.http.sender;
 
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
 
 import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
 
@@ -18,12 +19,23 @@ public class RequestExecutorServiceSettingsTests {
     }
 
     public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) {
+        return createRequestExecutorServiceSettings(queueCapacity, null);
+    }
+
+    public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(
+        @Nullable Integer queueCapacity,
+        @Nullable TimeValue staleDuration
+    ) {
         var settingsBuilder = Settings.builder();
 
         if (queueCapacity != null) {
             settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity);
         }
 
+        if (staleDuration != null) {
+            settingsBuilder.put(RequestExecutorServiceSettings.RATE_LIMIT_GROUP_STALE_DURATION_SETTING.getKey(), staleDuration);
+        }
+
         return createRequestExecutorServiceSettings(settingsBuilder.build());
     }
 
@ -18,13 +18,19 @@ import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
import org.elasticsearch.inference.InferenceServiceResults;
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.threadpool.Scheduler;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.inference.common.RateLimiter;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
|
||||||
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
|
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.mockito.ArgumentCaptor;
|
import org.mockito.ArgumentCaptor;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.BlockingQueue;
|
import java.util.concurrent.BlockingQueue;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
@ -42,10 +48,13 @@ import static org.elasticsearch.xpack.inference.external.http.sender.RequestExec
|
||||||
import static org.hamcrest.Matchers.instanceOf;
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
import static org.hamcrest.Matchers.is;
|
import static org.hamcrest.Matchers.is;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
import static org.mockito.Mockito.doAnswer;
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.times;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class RequestExecutorServiceTests extends ESTestCase {
|
public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
|
@ -70,7 +79,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testQueueSize_IsOne() {
|
public void testQueueSize_IsOne() {
|
||||||
var service = createRequestExecutorServiceWithMocks();
|
var service = createRequestExecutorServiceWithMocks();
|
||||||
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
||||||
|
|
||||||
assertThat(service.queueSize(), is(1));
|
assertThat(service.queueSize(), is(1));
|
||||||
}
|
}
|
||||||
|
@ -92,7 +101,20 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
assertTrue(service.isTerminated());
|
assertTrue(service.isTerminated());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testIsTerminated_AfterStopFromSeparateThread() throws Exception {
|
public void testCallingStartTwice_ThrowsAssertionException() throws InterruptedException {
|
||||||
|
var latch = new CountDownLatch(1);
|
||||||
|
var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class));
|
||||||
|
|
||||||
|
service.shutdown();
|
||||||
|
service.start();
|
||||||
|
latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
assertTrue(service.isTerminated());
|
||||||
|
var exception = expectThrows(AssertionError.class, service::start);
|
||||||
|
assertThat(exception.getMessage(), is("start() can only be called once"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testIsTerminated_AfterStopFromSeparateThread() {
|
||||||
var waitToShutdown = new CountDownLatch(1);
|
var waitToShutdown = new CountDownLatch(1);
|
||||||
var waitToReturnFromSend = new CountDownLatch(1);
|
var waitToReturnFromSend = new CountDownLatch(1);
|
||||||
|
|
||||||
|
@ -127,41 +149,48 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
assertTrue(service.isTerminated());
|
assertTrue(service.isTerminated());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSend_AfterShutdown_Throws() {
|
public void testExecute_AfterShutdown_Throws() {
|
||||||
var service = createRequestExecutorServiceWithMocks();
|
var service = createRequestExecutorServiceWithMocks();
|
||||||
|
|
||||||
service.shutdown();
|
service.shutdown();
|
||||||
|
|
||||||
|
var requestManager = RequestManagerTests.createMock("id");
|
||||||
var listener = new PlainActionFuture<InferenceServiceResults>();
|
var listener = new PlainActionFuture<InferenceServiceResults>();
|
||||||
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
|
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||||
|
|
||||||
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
|
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
|
||||||
assertThat(
|
assertThat(
|
||||||
thrownException.getMessage(),
|
thrownException.getMessage(),
|
||||||
is("Failed to enqueue task because the http executor service [test_service] has already shutdown")
|
is(
|
||||||
|
Strings.format(
|
||||||
|
"Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown",
|
||||||
|
requestManager.rateLimitGrouping().hashCode()
|
||||||
|
)
|
||||||
|
)
|
||||||
);
|
);
|
||||||
assertTrue(thrownException.isExecutorShutdown());
|
assertTrue(thrownException.isExecutorShutdown());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSend_Throws_WhenQueueIsFull() {
|
public void testExecute_Throws_WhenQueueIsFull() {
|
||||||
var service = new RequestExecutorService(
|
var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class));
|
||||||
"test_service",
|
|
||||||
threadPool,
|
|
||||||
null,
|
|
||||||
createRequestExecutorServiceSettings(1),
|
|
||||||
new SingleRequestManager(mock(RetryingHttpSender.class))
|
|
||||||
);
|
|
||||||
|
|
||||||
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
||||||
|
|
||||||
|
var requestManager = RequestManagerTests.createMock("id");
|
||||||
var listener = new PlainActionFuture<InferenceServiceResults>();
|
var listener = new PlainActionFuture<InferenceServiceResults>();
|
||||||
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
|
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||||
|
|
||||||
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
|
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
|
||||||
assertThat(
|
assertThat(
|
||||||
thrownException.getMessage(),
|
thrownException.getMessage(),
|
||||||
is("Failed to execute task because the http executor service [test_service] queue is full")
|
is(
|
||||||
|
Strings.format(
|
||||||
|
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
|
||||||
|
requestManager.rateLimitGrouping().hashCode()
|
||||||
|
)
|
||||||
|
)
|
||||||
);
|
);
|
||||||
assertFalse(thrownException.isExecutorShutdown());
|
assertFalse(thrownException.isExecutorShutdown());
|
||||||
}
|
}
|
||||||
|
@ -203,16 +232,11 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
assertTrue(service.isShutdown());
|
assertTrue(service.isShutdown());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSend_CallsOnFailure_WhenRequestTimesOut() {
|
public void testExecute_CallsOnFailure_WhenRequestTimesOut() {
|
||||||
var service = createRequestExecutorServiceWithMocks();
|
var service = createRequestExecutorServiceWithMocks();
|
||||||
|
|
||||||
var listener = new PlainActionFuture<InferenceServiceResults>();
|
var listener = new PlainActionFuture<InferenceServiceResults>();
|
||||||
service.execute(
|
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);
|
||||||
ExecutableRequestCreatorTests.createMock(),
|
|
||||||
new DocumentsOnlyInput(List.of()),
|
|
||||||
TimeValue.timeValueNanos(1),
|
|
||||||
listener
|
|
||||||
);
|
|
||||||
|
|
||||||
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
|
var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
|
||||||
|
@ -222,7 +246,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
|
public void testExecute_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
var headerKey = "not empty";
|
var headerKey = "not empty";
|
||||||
var headerValue = "value";
|
var headerValue = "value";
|
||||||
|
|
||||||
|
@@ -270,7 +294,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
             }
         };

-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);

         Future<?> executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service);

@@ -280,11 +304,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
         finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
     }

-    public void testSend_NotifiesTasksOfShutdown() {
+    public void testExecute_NotifiesTasksOfShutdown() {
         var service = createRequestExecutorServiceWithMocks();

+        var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id");
         var listener = new PlainActionFuture<InferenceServiceResults>();
-        service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);

         service.shutdown();
         service.start();
@@ -293,47 +318,62 @@ public class RequestExecutorServiceTests extends ESTestCase {

         assertThat(
             thrownException.getMessage(),
-            is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
+            is(
+                Strings.format(
+                    "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );
         assertTrue(thrownException.isExecutorShutdown());
         assertTrue(service.isTerminated());
     }

-    public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
+    public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
         @SuppressWarnings("unchecked")
         BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);

+        var requestSender = mock(RetryingHttpSender.class);
+
         var service = new RequestExecutorService(
-            getTestName(),
             threadPool,
             mockQueueCreator(queue),
             null,
             createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
+            requestSender,
+            Clock.systemUTC(),
+            RequestExecutorService.DEFAULT_SLEEPER,
+            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
         );

-        when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        var requestManager = RequestManagerTests.createMock(requestSender, "id");
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
+
+        when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
             service.shutdown();
             return null;
         });
         service.start();

         assertTrue(service.isTerminated());
-        verify(queue, times(2)).take();
     }

-    public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception {
+    public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
         @SuppressWarnings("unchecked")
         BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
-        when(queue.take()).thenThrow(new InterruptedException("failed"));
+        var sleeper = mock(RequestExecutorService.Sleeper.class);
+        doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());

         var service = new RequestExecutorService(
-            getTestName(),
             threadPool,
             mockQueueCreator(queue),
             null,
             createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
+            mock(RetryingHttpSender.class),
+            Clock.systemUTC(),
+            sleeper,
+            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
         );

         Future<?> executorTermination = threadPool.generic().submit(() -> {
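The two tests above pin down the worker-loop behavior this refactor depends on: the loop polls the queue with a timeout instead of blocking in `take()`, so it can re-check the shutdown flag between iterations, sleep according to the rate limiter, and reject whatever is still queued once it stops. A minimal, self-contained sketch of that loop shape (class and field names here are illustrative, not the production `RequestExecutorService`):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Sketch: poll-with-timeout lets the loop observe shutdown; leftover tasks
// are drained and "rejected" rather than silently dropped.
public class WorkerLoopSketch {
    final LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>();
    volatile boolean shutdown;
    final List<String> executed = new ArrayList<>();
    final List<String> rejected = new ArrayList<>();
    Runnable afterExecute = () -> {}; // hook standing in for the mocked queue answers in the tests

    public void runLoop() {
        try {
            while (!shutdown) {
                // take() would block forever on an empty queue; poll() returns
                // control so the shutdown flag is re-checked each iteration.
                String task = queue.poll(1, TimeUnit.MILLISECONDS);
                if (task != null) {
                    executed.add(task);
                    afterExecute.run();
                }
            }
        } catch (InterruptedException e) {
            // mirrors the interrupted-sleep test: interruption terminates the loop
            Thread.currentThread().interrupt();
        }
        queue.drainTo(rejected); // notify remaining tasks of shutdown
    }

    public static void main(String[] args) {
        WorkerLoopSketch loop = new WorkerLoopSketch();
        loop.queue.addAll(List.of("task-1", "task-2", "task-3"));
        loop.afterExecute = () -> loop.shutdown = true; // shut down after the first task
        loop.runLoop();
        System.out.println(loop.executed); // [task-1]
        System.out.println(loop.rejected); // [task-2, task-3]
    }
}
```

In the real service the loop additionally sleeps between polls for however long the rate limiter demands, which is why the refactor injects a mockable `Sleeper`.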
@@ -347,66 +387,30 @@ public class RequestExecutorServiceTests extends ESTestCase {
         executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);

         assertTrue(service.isTerminated());
-        verify(queue, times(1)).take();
-    }
-
-    public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception {
-        var mockTask = mock(RejectableTask.class);
-        @SuppressWarnings("unchecked")
-        BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
-
-        var service = new RequestExecutorService(
-            "test_service",
-            threadPool,
-            mockQueueCreator(queue),
-            null,
-            createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
-        );
-
-        doAnswer(invocation -> {
-            service.shutdown();
-            return mockTask;
-        }).doReturn(new NoopTask()).when(queue).take();
-
-        service.start();
-
-        assertTrue(service.isTerminated());
-        verify(queue, times(1)).take();
-
-        ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
-        verify(mockTask, times(1)).onRejection(argument.capture());
-        assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class));
-        assertThat(
-            argument.getValue().getMessage(),
-            is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
-        );
-
-        var rejectionException = (EsRejectedExecutionException) argument.getValue();
-        assertTrue(rejectionException.isExecutorShutdown());
     }

     public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException {
         var requestSender = mock(RetryingHttpSender.class);

         var settings = createRequestExecutorServiceSettings(1);
-        var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
+        var service = new RequestExecutorService(threadPool, null, settings, requestSender);

-        service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
-            new DocumentsOnlyInput(List.of()),
-            null,
-            new PlainActionFuture<>()
-        );
+        service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
         assertThat(service.queueSize(), is(1));

         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        var requestManager = RequestManagerTests.createMock(requestSender, "id");
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);

         var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
         assertThat(
             thrownException.getMessage(),
-            is("Failed to execute task because the http executor service [test_service] queue is full")
+            is(
+                Strings.format(
+                    "Failed to execute task for inference id [id] because the request service [%s] queue is full",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );

         settings.setQueueCapacity(2);
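The queue-full assertions above come from a bounded per-group queue: a non-blocking `offer()` fails immediately once the capacity is reached, and the service turns that failure into an `EsRejectedExecutionException` naming the inference id and rate-limit grouping. A small sketch of that rejection path (the message format copies the test's expectation; the class and method names are this sketch's own, not the production code):

```java
import java.util.concurrent.LinkedBlockingQueue;

// Sketch: bounded queue + fail-fast offer() -> "queue is full" rejection message.
public class BoundedQueueSketch {
    // Returns null when the task is accepted, otherwise the rejection message.
    public static String tryEnqueue(LinkedBlockingQueue<String> queue, String task, String inferenceId, int groupingHash) {
        if (queue.offer(task)) {
            return null; // accepted
        }
        return String.format(
            "Failed to execute task for inference id [%s] because the request service [%s] queue is full",
            inferenceId,
            groupingHash
        );
    }

    public static void main(String[] args) {
        LinkedBlockingQueue<String> queue = new LinkedBlockingQueue<>(1); // capacity 1, as in the test above
        System.out.println(tryEnqueue(queue, "first", "id", 42));  // null: accepted
        System.out.println(tryEnqueue(queue, "second", "id", 42)); // rejection message
    }
}
```

Changing the capacity setting at runtime, as `settings.setQueueCapacity(...)` does in the tests, amounts to swapping in a new bounded queue and transferring the pending tasks.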
@@ -426,7 +430,7 @@ public class RequestExecutorServiceTests extends ESTestCase {

         executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
         assertTrue(service.isTerminated());
-        assertThat(service.remainingQueueCapacity(), is(2));
+        assertThat(service.remainingQueueCapacity(requestManager), is(2));
     }

     public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException,
@@ -434,23 +438,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
         var requestSender = mock(RetryingHttpSender.class);

         var settings = createRequestExecutorServiceSettings(3);
-        var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
+        var service = new RequestExecutorService(threadPool, null, settings, requestSender);

         service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
+            RequestManagerTests.createMock(requestSender, "id"),
             new DocumentsOnlyInput(List.of()),
             null,
             new PlainActionFuture<>()
         );
         service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
+            RequestManagerTests.createMock(requestSender, "id"),
             new DocumentsOnlyInput(List.of()),
             null,
             new PlainActionFuture<>()
         );

         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        var requestManager = RequestManagerTests.createMock(requestSender, "id");
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
         assertThat(service.queueSize(), is(3));

         settings.setQueueCapacity(1);
@@ -470,7 +475,7 @@ public class RequestExecutorServiceTests extends ESTestCase {

         executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
         assertTrue(service.isTerminated());
-        assertThat(service.remainingQueueCapacity(), is(1));
+        assertThat(service.remainingQueueCapacity(requestManager), is(1));
         assertThat(service.queueSize(), is(0));

         var thrownException = expectThrows(
@@ -479,7 +484,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
         );
         assertThat(
             thrownException.getMessage(),
-            is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
+            is(
+                Strings.format(
+                    "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );
         assertTrue(thrownException.isExecutorShutdown());
     }
@@ -489,23 +499,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
         var requestSender = mock(RetryingHttpSender.class);

         var settings = createRequestExecutorServiceSettings(1);
-        var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
+        var service = new RequestExecutorService(threadPool, null, settings, requestSender);
+        var requestManager = RequestManagerTests.createMock(requestSender);

-        service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
-            new DocumentsOnlyInput(List.of()),
-            null,
-            new PlainActionFuture<>()
-        );
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
         assertThat(service.queueSize(), is(1));

         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, listener);

         var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
         assertThat(
             thrownException.getMessage(),
-            is("Failed to execute task because the http executor service [test_service] queue is full")
+            is(
+                Strings.format(
+                    "Failed to execute task for inference id [id] because the request service [%s] queue is full",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );

         settings.setQueueCapacity(0);
|
@ -525,7 +536,133 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
||||||
|
|
||||||
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
|
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
|
||||||
assertTrue(service.isTerminated());
|
assertTrue(service.isTerminated());
|
||||||
assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE));
|
assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
|
||||||
|
var mockRateLimiter = mock(RateLimiter.class);
|
||||||
|
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
|
||||||
|
|
||||||
|
var requestSender = mock(RetryingHttpSender.class);
|
||||||
|
var settings = createRequestExecutorServiceSettings(1);
|
||||||
|
var service = new RequestExecutorService(
|
||||||
|
threadPool,
|
||||||
|
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||||
|
null,
|
||||||
|
settings,
|
||||||
|
requestSender,
|
||||||
|
Clock.systemUTC(),
|
||||||
|
RequestExecutorService.DEFAULT_SLEEPER,
|
||||||
|
rateLimiterCreator
|
||||||
|
);
|
||||||
|
var requestManager = RequestManagerTests.createMock(requestSender);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||||
|
|
||||||
|
doAnswer(invocation -> {
|
||||||
|
service.shutdown();
|
||||||
|
return TimeValue.timeValueDays(1);
|
||||||
|
}).when(mockRateLimiter).timeToReserve(anyInt());
|
||||||
|
|
||||||
|
service.start();
|
||||||
|
|
||||||
|
verifyNoInteractions(requestSender);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() {
|
||||||
|
var mockRateLimiter = mock(RateLimiter.class);
|
||||||
|
when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0));
|
||||||
|
|
||||||
|
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
|
||||||
|
|
||||||
|
var requestSender = mock(RetryingHttpSender.class);
|
||||||
|
var settings = createRequestExecutorServiceSettings(1);
|
||||||
|
var service = new RequestExecutorService(
|
||||||
|
threadPool,
|
||||||
|
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||||
|
null,
|
||||||
|
settings,
|
||||||
|
requestSender,
|
||||||
|
Clock.systemUTC(),
|
||||||
|
RequestExecutorService.DEFAULT_SLEEPER,
|
||||||
|
rateLimiterCreator
|
||||||
|
);
|
||||||
|
var requestManager = RequestManagerTests.createMock(requestSender);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||||
|
|
||||||
|
when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0));
|
||||||
|
|
||||||
|
doAnswer(invocation -> {
|
||||||
|
service.shutdown();
|
||||||
|
return Void.TYPE;
|
||||||
|
}).when(requestSender).send(any(), any(), any(), any(), any(), any());
|
||||||
|
|
||||||
|
service.start();
|
||||||
|
|
||||||
|
verify(requestSender, times(1)).send(any(), any(), any(), any(), any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRemovesRateLimitGroup_AfterStaleDuration() {
|
||||||
|
var now = Instant.now();
|
||||||
|
var clock = mock(Clock.class);
|
||||||
|
when(clock.instant()).thenReturn(now);
|
||||||
|
|
||||||
|
var requestSender = mock(RetryingHttpSender.class);
|
||||||
|
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
|
||||||
|
var service = new RequestExecutorService(
|
||||||
|
threadPool,
|
||||||
|
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||||
|
null,
|
||||||
|
settings,
|
||||||
|
requestSender,
|
||||||
|
clock,
|
||||||
|
RequestExecutorService.DEFAULT_SLEEPER,
|
||||||
|
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
|
||||||
|
);
|
||||||
|
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||||
|
|
||||||
|
assertThat(service.numberOfRateLimitGroups(), is(1));
|
||||||
|
// the time is moved to after the stale duration, so now we should remove this grouping
|
||||||
|
when(clock.instant()).thenReturn(now.plus(Duration.ofDays(2)));
|
||||||
|
service.removeStaleGroupings();
|
||||||
|
assertThat(service.numberOfRateLimitGroups(), is(0));
|
||||||
|
|
||||||
|
var requestManager2 = RequestManagerTests.createMock(requestSender, "id2");
|
||||||
|
service.execute(requestManager2, new DocumentsOnlyInput(List.of()), null, listener);
|
||||||
|
|
||||||
|
assertThat(service.numberOfRateLimitGroups(), is(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testStartsCleanupThread() {
|
||||||
|
var mockThreadPool = mock(ThreadPool.class);
|
||||||
|
|
||||||
|
when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
|
||||||
|
|
||||||
|
var requestSender = mock(RetryingHttpSender.class);
|
||||||
|
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
|
||||||
|
var service = new RequestExecutorService(
|
||||||
|
mockThreadPool,
|
||||||
|
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||||
|
null,
|
||||||
|
settings,
|
||||||
|
requestSender,
|
||||||
|
Clock.systemUTC(),
|
||||||
|
RequestExecutorService.DEFAULT_SLEEPER,
|
||||||
|
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
|
||||||
|
);
|
||||||
|
|
||||||
|
service.shutdown();
|
||||||
|
service.start();
|
||||||
|
|
||||||
|
ArgumentCaptor<TimeValue> argument = ArgumentCaptor.forClass(TimeValue.class);
|
||||||
|
verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), argument.capture(), any());
|
||||||
|
assertThat(argument.getValue(), is(TimeValue.timeValueDays(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Future<?> submitShutdownRequest(
|
private Future<?> submitShutdownRequest(
|
||||||
|
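The new tests above drive the service through a mocked `RateLimiter` whose `timeToReserve(n)` reports how long the caller would have to wait before `n` tokens are available, and whose `reserve(n)` actually takes them. A minimal token-bucket sketch of those two semantics, in microseconds for determinism (this is illustrative only, not Elasticsearch's `RateLimiter` implementation, and the method names mirror the mocks above):

```java
// Token bucket: tokens refill continuously at a fixed rate; timeToReserve
// peeks at the wait without consuming, reserve consumes and returns the debt.
public class SimpleRateLimiter {
    private final double tokensPerMicro; // refill rate
    private double storedTokens;
    private long lastRefillMicros;

    public SimpleRateLimiter(double tokensPerSecond, double initialTokens, long nowMicros) {
        this.tokensPerMicro = tokensPerSecond / 1_000_000.0;
        this.storedTokens = initialTokens;
        this.lastRefillMicros = nowMicros;
    }

    private void refill(long nowMicros) {
        storedTokens += (nowMicros - lastRefillMicros) * tokensPerMicro;
        lastRefillMicros = nowMicros;
    }

    // How long the caller must sleep before `tokens` become available,
    // without taking them (mirrors timeToReserve in the tests above).
    public long timeToReserveMicros(int tokens, long nowMicros) {
        refill(nowMicros);
        double deficit = tokens - storedTokens;
        return deficit <= 0 ? 0 : (long) Math.ceil(deficit / tokensPerMicro);
    }

    // Take the tokens; the returned delay is the debt the caller pays by
    // sleeping (mirrors reserve in the tests above).
    public long reserveMicros(int tokens, long nowMicros) {
        long wait = timeToReserveMicros(tokens, nowMicros);
        storedTokens -= tokens;
        return wait;
    }

    public static void main(String[] args) {
        SimpleRateLimiter limiter = new SimpleRateLimiter(1.0, 1.0, 0); // 1 request/second, 1 stored token
        System.out.println(limiter.timeToReserveMicros(1, 0)); // 0: a token is available now
        limiter.reserveMicros(1, 0);
        System.out.println(limiter.timeToReserveMicros(1, 0)); // 1000000: must wait a full second
    }
}
```

The service's loop calls the peek form first, sleeps (via the injectable `Sleeper`) for the returned duration, and only then reserves and sends; that is exactly the gap the first test exploits by shutting down from inside `timeToReserve`.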
@@ -552,12 +689,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
     }

     private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) {
-        return new RequestExecutorService(
-            "test_service",
-            threadPool,
-            startupLatch,
-            createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(requestSender)
-        );
+        return new RequestExecutorService(threadPool, startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender);
     }
 }

@@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
 import org.elasticsearch.xpack.inference.external.request.RequestTests;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyList;
@@ -21,34 +22,47 @@ import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;

-public class ExecutableRequestCreatorTests {
+public class RequestManagerTests {
     public static RequestManager createMock() {
-        var mockCreator = mock(RequestManager.class);
-        when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {});
-
-        return mockCreator;
+        return createMock(mock(RequestSender.class));
+    }
+
+    public static RequestManager createMock(String inferenceEntityId) {
+        return createMock(mock(RequestSender.class), inferenceEntityId);
     }

     public static RequestManager createMock(RequestSender requestSender) {
-        return createMock(requestSender, "id");
+        return createMock(requestSender, "id", new RateLimitSettings(1));
     }

-    public static RequestManager createMock(RequestSender requestSender, String modelId) {
-        var mockCreator = mock(RequestManager.class);
+    public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) {
+        return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1));
+    }
+
+    public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) {
+        var mockManager = mock(RequestManager.class);

         doAnswer(invocation -> {
             @SuppressWarnings("unchecked")
-            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[5];
-            return (Runnable) () -> requestSender.send(
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[4];
+            requestSender.send(
                 mock(Logger.class),
-                RequestTests.mockRequest(modelId),
+                RequestTests.mockRequest(inferenceEntityId),
                 HttpClientContext.create(),
                 () -> false,
                 mock(ResponseHandler.class),
                 listener
             );
-        }).when(mockCreator).create(any(), anyList(), any(), any(), any(), any());

-        return mockCreator;
+            return Void.TYPE;
+        }).when(mockManager).execute(any(), anyList(), any(), any(), any());
+
+        // just return something consistent so the hashing works
+        when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId);
+
+        when(mockManager.rateLimitSettings()).thenReturn(settings);
+        when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId);
+
+        return mockManager;
     }
 }

@@ -1,27 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.sender;
-
-import org.apache.http.client.protocol.HttpClientContext;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verifyNoInteractions;
-import static org.mockito.Mockito.when;
-
-public class SingleRequestManagerTests extends ESTestCase {
-    public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() {
-        var requestCreator = mock(RequestManager.class);
-        var request = mock(InferenceRequest.class);
-        when(request.getRequestCreator()).thenReturn(requestCreator);
-
-        new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), HttpClientContext.create());
-        verifyNoInteractions(requestCreator);
-    }
-}

@@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;

 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;

@@ -59,7 +58,7 @@ public class SenderServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
             PlainActionFuture<Boolean> listener = new PlainActionFuture<>();

@@ -67,7 +66,7 @@ public class SenderServiceTests extends ESTestCase {

             listener.actionGet(TIMEOUT);
             verify(sender, times(1)).start();
-            verify(factory, times(1)).createSender(anyString());
+            verify(factory, times(1)).createSender();
         }

         verify(sender, times(1)).close();

@@ -79,7 +78,7 @@ public class SenderServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
             PlainActionFuture<Boolean> listener = new PlainActionFuture<>();

@@ -89,7 +88,7 @@ public class SenderServiceTests extends ESTestCase {
             service.start(mock(Model.class), listener);
             listener.actionGet(TIMEOUT);

-            verify(factory, times(1)).createSender(anyString());
+            verify(factory, times(1)).createSender();
             verify(sender, times(2)).start();
         }

@@ -76,7 +76,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;

@@ -819,7 +818,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -841,7 +840,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -112,7 +112,8 @@ public class AzureAiStudioChatCompletionServiceSettingsTests extends ESTestCase
         String xContentResult = Strings.toString(builder);

         assertThat(xContentResult, CoreMatchers.is("""
-            {"target":"target_value","provider":"openai","endpoint_type":"token"}"""));
+            {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
+            "rate_limit":{"requests_per_minute":3}}"""));
     }

     public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {
@@ -295,7 +295,7 @@ public class AzureAiStudioEmbeddingsServiceSettingsTests extends ESTestCase {

         assertThat(xContentResult, CoreMatchers.is("""
             {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
-            "dimensions":1024,"max_input_tokens":512}"""));
+            "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}"""));
     }

     public static HashMap<String, Object> createRequestSettingsMap(
@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -594,7 +593,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -616,7 +615,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;

 import java.io.IOException;
@@ -46,7 +47,8 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria
                     AzureOpenAiServiceFields.API_VERSION,
                     apiVersion
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null)));
@@ -63,18 +65,6 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria
             {"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}"""));
-    }
-
     @Override
     protected Writeable.Reader<AzureOpenAiCompletionServiceSettings> instanceReader() {
         return AzureOpenAiCompletionServiceSettings::new;
@@ -389,7 +389,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria
             "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_RateLimit() throws IOException {
+    public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser() throws IOException {
         var entity = new AzureOpenAiEmbeddingsServiceSettings(
             "resource",
             "deployment",
@@ -408,7 +408,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria

         assertThat(xContentResult, is("""
             {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
-            "dimensions":1024,"max_input_tokens":512}"""));
+            "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":1}}"""));
     }

     @Override
@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -613,7 +612,7 @@ public class CohereServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -635,7 +634,7 @@ public class CohereServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.util.HashMap;
@@ -28,7 +29,8 @@ public class CohereCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of()),
             new HashMap<>(Map.of("model", "overridden model")),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -34,7 +35,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
         var model = "model";

         var serviceSettings = CohereCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model))
+            new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null)));
@@ -55,7 +57,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute))));
@@ -72,18 +75,6 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
             {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WithFilteredObject_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url","model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereCompletionServiceSettings> instanceReader() {
         return CohereCompletionServiceSettings::new;
@@ -331,21 +331,6 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereEmbeddingsServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)),
-            CohereEmbeddingType.INT8
-        );
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """
-            "embedding_type":"byte"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereEmbeddingsServiceSettings> instanceReader() {
         return CohereEmbeddingsServiceSettings::new;
@@ -51,20 +51,6 @@ public class CohereRerankServiceSettingsTests extends AbstractWireSerializingTes
             "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereRerankServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3))
-        );
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        // TODO we probably shouldn't allow configuring these fields for reranking
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereRerankServiceSettings> instanceReader() {
         return CohereRerankServiceSettings::new;
@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.endsWith;
 import static org.hamcrest.Matchers.hasSize;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -494,7 +493,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -516,7 +515,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.net.URISyntaxException;
@@ -28,7 +29,8 @@ public class GoogleAiStudioCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of("model_id", "model")),
             new HashMap<>(Map.of()),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;

@@ -31,7 +32,10 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
     public void testFromMap_Request_CreatesSettingsCorrectly() {
         var model = "some model";

-        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)));
+        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
+        );

         assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null)));
     }
@@ -47,18 +51,6 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
             {"model_id":"model","rate_limit":{"requests_per_minute":360}}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioCompletionServiceSettings("model", null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader<GoogleAiStudioCompletionServiceSettings> instanceReader() {
         return GoogleAiStudioCompletionServiceSettings::new;
@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;

@@ -55,7 +56,8 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
                     ServiceFields.SIMILARITY,
                     similarity.toString()
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null)));
@@ -80,23 +82,6 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
             }"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
-            {
-                "model_id":"model",
-                "max_input_tokens": 1024,
-                "dimensions": 8,
-                "similarity": "dot_product"
-            }"""));
-    }
-
     @Override
     protected Writeable.Reader<GoogleAiStudioEmbeddingsServiceSettings> instanceReader() {
         return GoogleAiStudioEmbeddingsServiceSettings::new;
@@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.junit.After;
 import org.junit.Before;
@@ -33,7 +34,6 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.CoreMatchers.is;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +59,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -81,7 +81,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -111,7 +111,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             TaskType taskType,
             Map<String, Object> serviceSettings,
             Map<String, Object> secretSettings,
-            String failureMessage
+            String failureMessage,
+            ConfigurationParseContext context
         ) {
             return null;
         }
@@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -57,7 +58,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var dims = 384;
         var maxInputTokens = 128;
         {
-            var serviceSettings = HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)));
+            var serviceSettings = HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            );
             assertThat(serviceSettings, is(new HuggingFaceServiceSettings(url)));
         }
         {

@@ -73,7 +77,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
         assertThat(
             serviceSettings,

@@ -95,7 +100,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
         assertThat(
             serviceSettings,

@@ -105,7 +111,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     }

     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );

         assertThat(
             thrownException.getMessage(),

@@ -118,7 +127,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT)
         );

         assertThat(

@@ -136,7 +145,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT)
         );

         assertThat(

@@ -152,7 +161,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var similarity = "by_size";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)))
+            () -> HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(

@@ -175,18 +187,6 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
             {"url":"url","rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceServiceSettings(ServiceUtils.createUri("url"), null, null, null, new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url"}"""));
-    }
-
     @Override
     protected Writeable.Reader<HuggingFaceServiceSettings> instanceReader() {
         return HuggingFaceServiceSettings::new;

@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -32,7 +33,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin

     public void testFromMap() {
         var url = "https://www.abc.com";
-        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)));
+        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(
+            new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+            ConfigurationParseContext.PERSISTENT
+        );

         assertThat(new HuggingFaceElserServiceSettings(url), is(serviceSettings));
     }

@@ -40,7 +44,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(

@@ -55,7 +62,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     }

     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );

         assertThat(
             thrownException.getMessage(),

@@ -72,7 +82,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(

@@ -98,18 +111,6 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
             {"url":"url","max_input_tokens":512,"rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceElserServiceSettings(ServiceUtils.createUri("url"), new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url","max_input_tokens":512}"""));
-    }
-
     @Override
     protected Writeable.Reader<HuggingFaceElserServiceSettings> instanceReader() {
         return HuggingFaceElserServiceSettings::new;

@@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;

@@ -393,7 +392,7 @@ public class MistralServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -415,7 +414,7 @@ public class MistralServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -98,18 +98,6 @@ public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
             "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToFilteredXContent_WritesFilteredValues() throws IOException {
-        var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, CoreMatchers.is("""
-            {"model":"model_name","dimensions":1024,"max_input_tokens":512}"""));
-    }
-
     public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
         var outputBuffer = new BytesStreamOutput();
         var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));

@@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;

@@ -675,7 +674,7 @@ public class OpenAiServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -697,7 +696,7 @@ public class OpenAiServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;

@@ -48,7 +49,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(

@@ -77,7 +79,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(

@@ -101,7 +104,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertNull(serviceSettings.uri());

@@ -113,7 +117,10 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> OpenAiChatCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")))
+            () -> OpenAiChatCompletionServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(

@@ -132,7 +139,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var maxInputTokens = 8192;

         var serviceSettings = OpenAiChatCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens))
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)),
+            ConfigurationParseContext.PERSISTENT
         );

         assertNull(serviceSettings.uri());

@@ -144,7 +152,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );

@@ -164,7 +173,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );

@@ -213,19 +223,6 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
             {"model_id":"model","rate_limit":{"requests_per_minute":500}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new OpenAiChatCompletionServiceSettings("model", "url", "org", 1024, new RateLimitSettings(2));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"model_id":"model","url":"url","organization_id":"org",""" + """
-            "max_input_tokens":1024}"""));
-    }
-
     @Override
     protected Writeable.Reader<OpenAiChatCompletionServiceSettings> instanceReader() {
         return OpenAiChatCompletionServiceSettings::new;

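The hunks above all apply the same two migrations from this PR: service-settings `fromMap` parsing now takes an explicit `ConfigurationParseContext` (`REQUEST` vs. `PERSISTENT`), and the mocked `HttpRequestSender.Factory#createSender` no longer takes a string argument. The following is a minimal, self-contained sketch of the parsing-side pattern only; `SketchServiceSettings` and its validation are simplified stand-ins invented for illustration, not the real Elasticsearch classes.

```java
import java.util.HashMap;
import java.util.Map;

public class ParseContextSketch {
    // Stand-in for org.elasticsearch.xpack.inference.services.ConfigurationParseContext:
    // callers state whether the settings map came from a user request or from
    // already-persisted configuration.
    enum ConfigurationParseContext { REQUEST, PERSISTENT }

    // Simplified stand-in for a service-settings class; the real fromMap
    // implementations also validate similarity, dimensions, rate limits, etc.
    record SketchServiceSettings(String url) {
        static SketchServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
            Object url = map.remove("url");
            if (url == null) {
                // Threading the context through lets validation rules differ
                // between request-time and persistent parsing.
                throw new IllegalArgumentException("[url] is required when parsing in context [" + context + "]");
            }
            return new SketchServiceSettings(url.toString());
        }
    }

    public static void main(String[] args) {
        // Mirrors the new two-argument call shape seen throughout the tests above.
        var settings = SketchServiceSettings.fromMap(
            new HashMap<>(Map.of("url", "https://www.abc.com")),
            ConfigurationParseContext.PERSISTENT
        );
        System.out.println(settings.url());
    }
}
```

Passing the context as an explicit parameter (rather than storing it on the settings object) keeps parse-time-only concerns out of the serialized state, which is consistent with the tests above dropping the `toXContent` rate-limit assertions.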
Some files were not shown because too many files have changed in this diff.