mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 23:27:25 -04:00
[ML] Inference API rate limit queuing logic refactor (#107706)
* Adding new executor
* Adding in queuing logic
* working tests
* Added cleanup task
* Update docs/changelog/107706.yaml
* Updating yml
* deregistering callbacks for settings changes
* Cleaning up code
* Update docs/changelog/107706.yaml
* Fixing rate limit settings bug and only sleeping least amount
* Removing debug logging
* Removing commented code
* Renaming feedback
* fixing tests
* Updating docs and validation
* Fixing source blocks
* Adjusting cancel logic
* Reformatting ascii
* Addressing feedback
* adding rate limiting for google embeddings and mistral

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
cd84749d87
commit
fdb5058b13
102 changed files with 1499 additions and 937 deletions
5 docs/changelog/107706.yaml Normal file
@@ -0,0 +1,5 @@
pr: 107706
summary: Add rate limiting support for the Inference API
area: Machine Learning
type: enhancement
issues: []
@@ -7,21 +7,17 @@ experimental[]
Creates an {infer} endpoint to perform an {infer} task.

IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in
{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure
OpenAI, Google AI Studio or Hugging Face. For built-in models and models
uploaded through Eland, the {infer} APIs offer an alternative way to use and
manage trained models. However, if you do not plan to use the {infer} APIs to
use these models or if you want to use non-NLP models, use the
{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face.
For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models.
However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the
<<ml-df-trained-models-apis>>.

[discrete]
[[put-inference-api-request]]
==== {api-request-title}

`PUT /_inference/<task_type>/<inference_id>`

[discrete]
[[put-inference-api-prereqs]]
==== {api-prereq-title}

@@ -29,7 +25,6 @@ use these models or if you want to use non-NLP models, use the
* Requires the `manage_inference` <<privileges-list-cluster,cluster privilege>>
(the built-in `inference_admin` role grants this privilege)

[discrete]
[[put-inference-api-desc]]
==== {api-description-title}

@@ -48,25 +43,23 @@ The following services are available through the {infer} API:
* Hugging Face
* OpenAI

[discrete]
[[put-inference-api-path-params]]
==== {api-path-parms-title}

`<inference_id>`::
(Required, string)
The unique identifier of the {infer} endpoint.

`<task_type>`::
(Required, string)
The type of the {infer} task that the model will perform. Available task types:
The type of the {infer} task that the model will perform.
Available task types:
* `completion`,
* `rerank`,
* `sparse_embedding`,
* `text_embedding`.

[discrete]
[[put-inference-api-request-body]]
==== {api-request-body-title}
@@ -78,21 +71,18 @@ Available services:

* `azureopenai`: specify the `completion` or `text_embedding` task type to use the Azure OpenAI service.
* `azureaistudio`: specify the `completion` or `text_embedding` task type to use the Azure AI Studio service.
* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the
Cohere service.
* `elasticsearch`: specify the `text_embedding` task type to use the E5
built-in model or text embedding models uploaded by Eland.
* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service.
* `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland.
* `elser`: specify the `sparse_embedding` task type to use the ELSER service.
* `googleaistudio`: specify the `completion` task to use the Google AI Studio service.
* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face
service.
* `openai`: specify the `completion` or `text_embedding` task type to use the
OpenAI service.
* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service.
* `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service.

`service_settings`::
(Required, object)
Settings used to install the {infer} model. These settings are specific to the
Settings used to install the {infer} model.
These settings are specific to the
`service` you specified.
+
.`service_settings` for the `azureaistudio` service
@@ -104,11 +94,10 @@ Settings used to install the {infer} model. These settings are specific to the
A valid API key of your Azure AI Studio model deployment.
This key can be found on the overview page for your deployment in the management section of your https://ai.azure.com/[Azure AI Studio] account.

IMPORTANT: You need to provide the API key only once, during the {infer} model
creation. The <<get-inference-api>> does not retrieve your API key. After
creating the {infer} model, you cannot change the associated API key. If you
want to use a different API key, delete the {infer} model and recreate it with
the same name and the updated API key.
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`target`:::
(Required, string)
@@ -142,11 +131,13 @@ For "real-time" endpoints which are billed per hour of usage, specify `realtime`
By default, the `azureaistudio` service sets the number of requests allowed per minute to `240`.
This helps to minimize the number of rate limit errors returned from Azure AI Studio.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
```
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
```
----
=====
+
.`service_settings` for the `azureopenai` service
@@ -181,6 +172,22 @@ Your Azure OpenAI deployments can be found though the https://oai.azure.com/[Azu
The Azure API version ID to use.
We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings[latest supported non-preview version].

`rate_limit`:::
(Optional, object)
The `azureopenai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `1440`.
For `completion` it is set to `120`.
This helps to minimize the number of rate limit errors returned from Azure.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for Azure can be found in the https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits[Quota limits docs] and https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest[How to change the quotas].
=====
+
.`service_settings` for the `cohere` service
@@ -188,24 +195,24 @@ We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/opena
=====
`api_key`:::
(Required, string)
A valid API key of your Cohere account. You can find your Cohere API keys or you
can create a new one
A valid API key of your Cohere account.
You can find your Cohere API keys or you can create a new one
https://dashboard.cohere.com/api-keys[on the API keys settings page].

IMPORTANT: You need to provide the API key only once, during the {infer} model
creation. The <<get-inference-api>> does not retrieve your API key. After
creating the {infer} model, you cannot change the associated API key. If you
want to use a different API key, delete the {infer} model and recreate it with
the same name and the updated API key.
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`embedding_type`::
(Optional, string)
Only for `text_embedding`. Specifies the types of embeddings you want to get
back. Defaults to `float`.
Only for `text_embedding`.
Specifies the types of embeddings you want to get back.
Defaults to `float`.
Valid values are:
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
* `float`: use it for the default float embeddings.
* `int8`: use it for signed int8 embeddings.
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
* `float`: use it for the default float embeddings.
* `int8`: use it for signed int8 embeddings.

`model_id`::
(Optional, string)
@@ -214,50 +221,68 @@ To review the available `rerank` models, refer to the
https://docs.cohere.com/reference/rerank-1[Cohere docs].

To review the available `text_embedding` models, refer to the
https://docs.cohere.com/reference/embed[Cohere docs]. The default value for
https://docs.cohere.com/reference/embed[Cohere docs].
The default value for
`text_embedding` is `embed-english-v2.0`.

`rate_limit`:::
(Optional, object)
By default, the `cohere` service sets the number of requests allowed per minute to `10000`.
This value is the same for all task types.
This helps to minimize the number of rate limit errors returned from Cohere.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
+
More information about Cohere's rate limits can be found in https://docs.cohere.com/docs/going-live#production-key-specifications[Cohere's production key docs].

=====
+
.`service_settings` for the `elasticsearch` service
[%collapsible%closed]
=====

`model_id`:::
(Required, string)
The name of the model to use for the {infer} task. It can be the
ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or
a text embedding model already
The name of the model to use for the {infer} task.
It can be the ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or a text embedding model already
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].

`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the
number of available processors per node divided by the `num_threads`.
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.

`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not
exceed the number of available processors per node divided by the number of
allocations. Must be a power of 2. Max allowed value is 32.
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.

=====
+
.`service_settings` for the `elser` service
[%collapsible%closed]
=====

`num_allocations`:::
(Required, integer)
The number of model allocations to create. `num_allocations` must not exceed the
number of available processors per node divided by the `num_threads`.
The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`.

`num_threads`:::
(Required, integer)
The number of threads to use by each model allocation. `num_threads` must not
exceed the number of available processors per node divided by the number of
allocations. Must be a power of 2. Max allowed value is 32.
The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations.
Must be a power of 2. Max allowed value is 32.

=====
+
.`service_settings` for the `googleaistudio` service
[%collapsible%closed]
=====

`api_key`:::
(Required, string)
A valid API key for the Google Gemini API.
@@ -274,76 +299,113 @@ This helps to minimize the number of rate limit errors returned from Google AI S
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
--
```
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
```
----
--

=====
+
.`service_settings` for the `hugging_face` service
[%collapsible%closed]
=====

`api_key`:::
(Required, string)
A valid access token of your Hugging Face account. You can find your Hugging
Face access tokens or you can create a new one
A valid access token of your Hugging Face account.
You can find your Hugging Face access tokens or you can create a new one
https://huggingface.co/settings/tokens[on the settings page].

IMPORTANT: You need to provide the API key only once, during the {infer} model
creation. The <<get-inference-api>> does not retrieve your API key. After
creating the {infer} model, you cannot change the associated API key. If you
want to use a different API key, delete the {infer} model and recreate it with
the same name and the updated API key.
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`url`:::
(Required, string)
The URL endpoint to use for the requests.

`rate_limit`:::
(Optional, object)
By default, the `huggingface` service sets the number of requests allowed per minute to `3000`.
This helps to minimize the number of rate limit errors returned from Hugging Face.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----

=====
+
.`service_settings` for the `openai` service
[%collapsible%closed]
=====

`api_key`:::
(Required, string)
A valid API key of your OpenAI account. You can find your OpenAI API keys in
your OpenAI account under the
A valid API key of your OpenAI account.
You can find your OpenAI API keys in your OpenAI account under the
https://platform.openai.com/api-keys[API keys section].

IMPORTANT: You need to provide the API key only once, during the {infer} model
creation. The <<get-inference-api>> does not retrieve your API key. After
creating the {infer} model, you cannot change the associated API key. If you
want to use a different API key, delete the {infer} model and recreate it with
the same name and the updated API key.
IMPORTANT: You need to provide the API key only once, during the {infer} model creation.
The <<get-inference-api>> does not retrieve your API key.
After creating the {infer} model, you cannot change the associated API key.
If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key.

`model_id`:::
(Required, string)
The name of the model to use for the {infer} task. Refer to the
The name of the model to use for the {infer} task.
Refer to the
https://platform.openai.com/docs/guides/embeddings/what-are-embeddings[OpenAI documentation]
for the list of available text embedding models.

`organization_id`:::
(Optional, string)
The unique identifier of your organization. You can find the Organization ID in
your OpenAI account under
The unique identifier of your organization.
You can find the Organization ID in your OpenAI account under
https://platform.openai.com/account/organization[**Settings** > **Organizations**].

`url`:::
(Optional, string)
The URL endpoint to use for the requests. Can be changed for testing purposes.
The URL endpoint to use for the requests.
Can be changed for testing purposes.
Defaults to `https://api.openai.com/v1/embeddings`.

`rate_limit`:::
(Optional, object)
The `openai` service sets a default number of requests allowed per minute depending on the task type.
For `text_embedding` it is set to `3000`.
For `completion` it is set to `500`.
This helps to minimize the number of rate limit errors returned from OpenAI.
To modify this, set the `requests_per_minute` setting of this object in your service settings:
+
[source,text]
----
"rate_limit": {
    "requests_per_minute": <<number_of_requests>>
}
----
+
More information about the rate limits for OpenAI can be found in your https://platform.openai.com/account/limits[Account limits].

=====

`task_settings`::
(Optional, object)
Settings to configure the {infer} task. These settings are specific to the
Settings to configure the {infer} task.
These settings are specific to the
`<task_type>` you specified.
+
.`task_settings` for the `completion` task type
[%collapsible%closed]
=====

`do_sample`:::
(Optional, float)
For the `azureaistudio` service only.
@@ -358,8 +420,8 @@ Defaults to 64.

`user`:::
(Optional, string)
For `openai` service only. Specifies the user issuing the request, which can be
used for abuse detection.
For `openai` service only.
Specifies the user issuing the request, which can be used for abuse detection.

`temperature`:::
(Optional, float)
@@ -378,45 +440,46 @@ Should not be used if `temperature` is specified.
.`task_settings` for the `rerank` task type
[%collapsible%closed]
=====

`return_documents`::
(Optional, boolean)
For `cohere` service only. Specify whether to return doc text within the
results.
For `cohere` service only.
Specify whether to return doc text within the results.

`top_n`::
(Optional, integer)
The number of most relevant documents to return, defaults to the number of the
documents.
The number of most relevant documents to return, defaults to the number of the documents.

=====
+
.`task_settings` for the `text_embedding` task type
[%collapsible%closed]
=====

`input_type`:::
(Optional, string)
For `cohere` service only. Specifies the type of input passed to the model.
For `cohere` service only.
Specifies the type of input passed to the model.
Valid values are:
* `classification`: use it for embeddings passed through a text classifier.
* `clustering`: use it for the embeddings run through a clustering algorithm.
* `ingest`: use it for storing document embeddings in a vector database.
* `search`: use it for storing embeddings of search queries run against a
vector database to find relevant documents.
* `classification`: use it for embeddings passed through a text classifier.
* `clustering`: use it for the embeddings run through a clustering algorithm.
* `ingest`: use it for storing document embeddings in a vector database.
* `search`: use it for storing embeddings of search queries run against a vector database to find relevant documents.

`truncate`:::
(Optional, string)
For `cohere` service only. Specifies how the API handles inputs longer than the
maximum token length. Defaults to `END`. Valid values are:
* `NONE`: when the input exceeds the maximum input token length an error is
returned.
* `START`: when the input exceeds the maximum input token length the start of
the input is discarded.
* `END`: when the input exceeds the maximum input token length the end of
the input is discarded.
For `cohere` service only.
Specifies how the API handles inputs longer than the maximum token length.
Defaults to `END`.
Valid values are:
* `NONE`: when the input exceeds the maximum input token length an error is returned.
* `START`: when the input exceeds the maximum input token length the start of the input is discarded.
* `END`: when the input exceeds the maximum input token length the end of the input is discarded.

`user`:::
(Optional, string)
For `openai`, `azureopenai` and `azureaistudio` services only. Specifies the user issuing the
request, which can be used for abuse detection.
For `openai`, `azureopenai` and `azureaistudio` services only.
Specifies the user issuing the request, which can be used for abuse detection.

=====
[discrete]
@@ -470,7 +533,6 @@ PUT _inference/completion/azure_ai_studio_completion

The list of chat completion models that you can choose from in your deployment can be found in the https://ai.azure.com/explore/models?selectedTask=chat-completion[Azure AI Studio model explorer].

[discrete]
[[inference-example-azureopenai]]
===== Azure OpenAI service
@@ -519,7 +581,6 @@ The list of chat completion models that you can choose from in your Azure OpenAI
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-models[GPT-4 and GPT-4 Turbo models]
* https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35[GPT-3.5]

[discrete]
[[inference-example-cohere]]
===== Cohere service
@@ -565,7 +626,6 @@ PUT _inference/rerank/cohere-rerank
For more examples, also review the
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].

[discrete]
[[inference-example-e5]]
===== E5 via the `elasticsearch` service
@@ -586,10 +646,9 @@ PUT _inference/text_embedding/my-e5-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of one of the built-in E5 models. Valid values
are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`. For
further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].
<1> The `model_id` must be the ID of one of the built-in E5 models.
Valid values are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`.
For further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].

[discrete]
[[inference-example-elser]]
@@ -597,8 +656,7 @@ further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation].

The following example shows how to create an {infer} endpoint called
`my-elser-model` to perform a `sparse_embedding` task type.
Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more
info.
Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more info.

[source,console]
------------------------------------------------------------
@@ -672,16 +730,17 @@ PUT _inference/text_embedding/hugging-face-embeddings
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> A valid Hugging Face access token. You can find it on the
<1> A valid Hugging Face access token.
You can find it on the
https://huggingface.co/settings/tokens[settings page of your account].
<2> The {infer} endpoint URL you created on Hugging Face.

Create a new {infer} endpoint on
https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an
endpoint URL. Select the model you want to use on the new endpoint creation page
- for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings`
task under the Advanced configuration section. Create the endpoint. Copy the URL
after the endpoint initialization has been finished.
https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an endpoint URL.
Select the model you want to use on the new endpoint creation page - for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings`
task under the Advanced configuration section.
Create the endpoint.
Copy the URL after the endpoint initialization has been finished.

[discrete]
[[inference-example-hugging-face-supported-models]]
@@ -695,7 +754,6 @@ The list of recommended models for the Hugging Face service:
* https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base]
* https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small]

[discrete]
[[inference-example-eland]]
===== Models uploaded by Eland via the elasticsearch service
@@ -716,11 +774,9 @@ PUT _inference/text_embedding/my-msmarco-minilm-model
}
------------------------------------------------------------
// TEST[skip:TBD]
<1> The `model_id` must be the ID of a text embedding model which has already
been
<1> The `model_id` must be the ID of a text embedding model which has already been
{ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland].

[discrete]
[[inference-example-openai]]
===== OpenAI service
@@ -756,4 +812,3 @@ PUT _inference/completion/openai-completion
}
------------------------------------------------------------
// TEST[skip:TBD]
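The `rate_limit.requests_per_minute` settings documented above are enforced client-side by the queuing logic this PR adds; per the commit message, the executor sleeps only the least amount needed before dispatching the next request. A minimal sketch of that idea (hypothetical class and method names, not the PR's actual types): convert the allowed requests per minute into a minimum interval between dispatches and compute how long the caller must wait.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical simplification of client-side rate limiting:
// requests_per_minute -> minimum nanoseconds between dispatches.
final class SimpleRateLimiter {
    private final long intervalNanos;
    private long nextAllowedNanos;

    SimpleRateLimiter(long requestsPerMinute) {
        this.intervalNanos = TimeUnit.MINUTES.toNanos(1) / requestsPerMinute;
        this.nextAllowedNanos = System.nanoTime();
    }

    // Returns how long the caller must wait before sending the next request;
    // the queue sleeps only this least amount rather than a fixed interval.
    synchronized long reserveWaitNanos() {
        long now = System.nanoTime();
        long wait = Math.max(0, nextAllowedNanos - now);
        nextAllowedNanos = Math.max(now, nextAllowedNanos) + intervalNanos;
        return wait;
    }
}

public class RateLimitSketch {
    public static void main(String[] args) {
        // 240 is the azureaistudio default documented above.
        SimpleRateLimiter limiter = new SimpleRateLimiter(240);
        long first = limiter.reserveWaitNanos();   // first request goes out immediately
        long second = limiter.reserveWaitNanos();  // second must wait roughly 60s / 240
        System.out.println(first == 0);
        System.out.println(second > 0);
    }
}
```

With `requests_per_minute` of 240, the computed interval is 250 ms, which matches why raising the setting trades fewer sleeps for more provider-side 429 responses.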
@@ -88,6 +88,13 @@ public class TimeValue implements Comparable<TimeValue> {
        return new TimeValue(days, TimeUnit.DAYS);
    }

    /**
     * @return the {@link TimeValue} object that has the least duration.
     */
    public static TimeValue min(TimeValue time1, TimeValue time2) {
        return time1.compareTo(time2) < 0 ? time1 : time2;
    }

    /**
     * @return the unit used for this time value, see {@link #duration()}
     */
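The new `TimeValue.min` helper above is what lets the rate limiter sleep only the smaller of two candidate durations. The same `Comparable`-based pattern can be sketched with the JDK's `java.time.Duration` (a standalone illustration, not Elasticsearch code):

```java
import java.time.Duration;

public class MinDurationSketch {
    // Mirrors TimeValue.min: strict compareTo, so the second argument
    // is returned when the two durations are equal.
    static Duration min(Duration a, Duration b) {
        return a.compareTo(b) < 0 ? a : b;
    }

    public static void main(String[] args) {
        Duration untilNextRequest = Duration.ofMillis(250);
        Duration untilCleanupTask = Duration.ofSeconds(10);
        // Sleep only the least amount so neither deadline is overshot.
        System.out.println(min(untilNextRequest, untilCleanupTask)); // PT0.25S
    }
}
```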
@@ -17,6 +17,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.object.HasToString.hasToString;
@@ -231,6 +232,12 @@ public class TimeValueTests extends ESTestCase {
        assertThat(ex.getMessage(), containsString("duration cannot be negative"));
    }

    public void testMin() {
        assertThat(TimeValue.min(TimeValue.ZERO, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(0)));
        assertThat(TimeValue.min(TimeValue.MAX_VALUE, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(1)));
        assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE));
    }

    private TimeUnit randomTimeUnitObject() {
        return randomFrom(
            TimeUnit.NANOSECONDS,
@@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor {
    private final ServiceComponents serviceComponents;

    public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
        // TODO Batching - accept a class that can handle batching
        this.sender = Objects.requireNonNull(sender);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }
@@ -36,6 +36,7 @@ public class CohereEmbeddingsAction implements ExecutableAction {
            model.getServiceSettings().getCommonSettings().uri(),
            "Cohere embeddings"
        );
        // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
        requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool);
    }
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -37,17 +36,16 @@ public class AzureAiStudioChatCompletionRequestManager extends AzureAiStudioRequ
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }

    private static ResponseHandler createCompletionHandler() {
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -41,17 +40,16 @@ public class AzureAiStudioEmbeddingsRequestManager extends AzureAiStudioRequestM
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
        AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model);
        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }

    private static ResponseHandler createEmbeddingsHandler() {
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -43,16 +42,15 @@ public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManag
    }

    @Override
    public Runnable create(
    public void execute(
        @Nullable String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model);
        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }

}
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -55,16 +54,15 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
        AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -38,11 +38,16 @@ abstract class BaseRequestManager implements RequestManager {

    @Override
    public Object rateLimitGrouping() {
        return rateLimitGroup;
        // If two inference endpoints have the same information defining the group but different rate limits,
        // they should be in different groups; otherwise whoever initially created the group sets the rate
        // and the other inference endpoint's rate limit is ignored
        return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
    }

    @Override
    public RateLimitSettings rateLimitSettings() {
        return rateLimitSettings;
    }

    private record EndpointGrouping(Object group, RateLimitSettings settings) {}
}
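The grouping comment above can be illustrated with a small, self-contained sketch. The `RateLimit` and `EndpointGrouping` names here are hypothetical stand-ins for the real Elasticsearch types: because Java record equality compares every component, two endpoints that share the same grouping information but configure different rate limits produce unequal keys and therefore land in different rate limit groups.

```java
// Hypothetical stand-ins for the real grouping key. Record equality
// includes the rate limit settings, so endpoints with identical group
// info but different limits end up in different rate limit groups.
record RateLimit(long requestsPerMinute) {}

record EndpointGrouping(Object group, RateLimit settings) {}

class GroupingSketch {
    static boolean sameGroup(EndpointGrouping a, EndpointGrouping b) {
        // A ConcurrentHashMap keyed by EndpointGrouping relies on this
        // equals/hashCode behavior to separate the groups.
        return a.equals(b);
    }
}
```

With this key, two endpoints sharing the same account info but requesting 100 and 500 requests per minute do not collide, so neither silently inherits the other's rate.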
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -46,16 +45,15 @@ public class CohereCompletionRequestManager extends CohereRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        CohereCompletionRequest request = new CohereCompletionRequest(input, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -44,16 +43,15 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -44,16 +43,15 @@ public class CohereRerankRequestManager extends CohereRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        CohereRerankRequest request = new CohereRerankRequest(query, input, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -23,7 +23,6 @@ record ExecutableInferenceRequest(
    RequestSender requestSender,
    Logger logger,
    Request request,
    HttpClientContext context,
    ResponseHandler responseHandler,
    Supplier<Boolean> hasFinished,
    ActionListener<InferenceServiceResults> listener

@@ -34,7 +33,7 @@ record ExecutableInferenceRequest(
        var inferenceEntityId = request.createHttpRequest().inferenceEntityId();

        try {
            requestSender.send(logger, request, context, hasFinished, responseHandler, listener);
            requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
        } catch (Exception e) {
            var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
            logger.warn(errorMessage, e);
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -42,15 +41,14 @@ public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioReques
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model);
        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -48,17 +47,16 @@ public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioReques
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
        GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -15,6 +15,8 @@ import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;

@@ -39,30 +41,28 @@ public class HttpRequestSender implements Sender {
        private final ServiceComponents serviceComponents;
        private final HttpClientManager httpClientManager;
        private final ClusterService clusterService;
        private final SingleRequestManager requestManager;
        private final RequestSender requestSender;

        public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
            this.serviceComponents = Objects.requireNonNull(serviceComponents);
            this.httpClientManager = Objects.requireNonNull(httpClientManager);
            this.clusterService = Objects.requireNonNull(clusterService);

            var requestSender = new RetryingHttpSender(
            requestSender = new RetryingHttpSender(
                this.httpClientManager.getHttpClient(),
                serviceComponents.throttlerManager(),
                new RetrySettings(serviceComponents.settings(), clusterService),
                serviceComponents.threadPool()
            );
            requestManager = new SingleRequestManager(requestSender);
        }

        public Sender createSender(String serviceName) {
        public Sender createSender() {
            return new HttpRequestSender(
                serviceName,
                serviceComponents.threadPool(),
                httpClientManager,
                clusterService,
                serviceComponents.settings(),
                requestManager
                requestSender
            );
        }
    }

@@ -71,26 +71,24 @@ public class HttpRequestSender implements Sender {

    private final ThreadPool threadPool;
    private final HttpClientManager manager;
    private final RequestExecutorService service;
    private final RequestExecutor service;
    private final AtomicBoolean started = new AtomicBoolean(false);
    private final CountDownLatch startCompleted = new CountDownLatch(1);

    private HttpRequestSender(
        String serviceName,
        ThreadPool threadPool,
        HttpClientManager httpClientManager,
        ClusterService clusterService,
        Settings settings,
        SingleRequestManager requestManager
        RequestSender requestSender
    ) {
        this.threadPool = Objects.requireNonNull(threadPool);
        this.manager = Objects.requireNonNull(httpClientManager);
        service = new RequestExecutorService(
            serviceName,
            threadPool,
            startCompleted,
            new RequestExecutorServiceSettings(settings, clusterService),
            requestManager
            requestSender
        );
    }
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -55,26 +54,17 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        var truncatedInput = truncate(input, model.getTokenLimit());
        var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);

        return new ExecutableInferenceRequest(
            requestSender,
            logger,
            request,
            context,
            responseHandler,
            hasRequestCompletedFunction,
            listener
        );
        execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
    }

    record RateLimitGrouping(int accountHash) {
@@ -19,9 +19,9 @@ import java.util.function.Supplier;
public interface InferenceRequest {

    /**
     * Returns the creator that handles building an executable request based on the input provided.
     * Returns the manager that handles building and executing an inference request.
     */
    RequestManager getRequestCreator();
    RequestManager getRequestManager();

    /**
     * Returns the query associated with this request. Used for Rerank tasks.
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -51,18 +50,17 @@ public class MistralEmbeddingsRequestManager extends BaseRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
        MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }

    record RateLimitGrouping(int keyHashCode) {
@@ -1,52 +0,0 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;

import java.util.List;
import java.util.function.Supplier;

class NoopTask implements RejectableTask {

    @Override
    public RequestManager getRequestCreator() {
        return null;
    }

    @Override
    public String getQuery() {
        return null;
    }

    @Override
    public List<String> getInput() {
        return null;
    }

    @Override
    public ActionListener<InferenceServiceResults> getListener() {
        return null;
    }

    @Override
    public boolean hasCompleted() {
        return true;
    }

    @Override
    public Supplier<Boolean> getRequestCompletedFunction() {
        return () -> true;
    }

    @Override
    public void onRejection(Exception e) {

    }
}
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -43,17 +42,16 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        @Nullable String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }

    private static ResponseHandler createCompletionHandler() {
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -55,17 +54,16 @@ public class OpenAiEmbeddingsRequestManager extends OpenAiRequestManager {
    }

    @Override
    public Runnable create(
    public void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    ) {
        var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
        OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model);

        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }
}
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

@@ -17,21 +16,31 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue;
import org.elasticsearch.xpack.inference.common.RateLimiter;
import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

/**
 * A service for queuing and executing {@link RequestTask}. This class is useful because the
@@ -45,7 +54,18 @@ import static org.elasticsearch.core.Strings.format;
 * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info.
 */
class RequestExecutorService implements RequestExecutor {
    private static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> QUEUE_CREATOR =

    /**
     * Provides dependency injection mainly for testing
     */
    interface Sleeper {
        void sleep(TimeValue sleepTime) throws InterruptedException;
    }

    // default for tests
    static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
    // default for tests
    static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
        new AdjustableCapacityBlockingQueue.QueueCreator<>() {
            @Override
            public BlockingQueue<RejectableTask> create(int capacity) {
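The `Sleeper` hook above exists so tests can observe requested sleeps without incurring real delays. A hedged sketch of the idea, with an illustrative `RecordingSleeper` test double (not the actual Elasticsearch test code):

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative test double for a Sleeper-style hook: it records each
// requested sleep duration (in millis) instead of blocking the thread,
// so a test can assert on the waits the executor computed.
class RecordingSleeper {
    final List<Long> sleptMillis = new ArrayList<>();

    void sleep(long millis) {
        // a production implementation would call TimeUnit.MILLISECONDS.sleep(millis)
        sleptMillis.add(millis);
    }
}
```

Injecting a double like this keeps rate-limit tests fast and deterministic while still verifying the computed back-off values.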
@@ -65,86 +85,116 @@ class RequestExecutorService implements RequestExecutor {
            }
        };

    /**
     * Provides dependency injection mainly for testing
     */
    interface RateLimiterCreator {
        RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit);
    }

    // default for testing
    static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new;
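The `RateLimiterCreator` signature above suggests token-bucket style parameters (an accumulation cap and a refill rate per time unit). A simplified token-bucket sketch along those lines, which is not the actual `RateLimiter` implementation, might look like:

```java
import java.util.concurrent.TimeUnit;

// Simplified token-bucket sketch: tokens accumulate over time up to a cap,
// each request reserves one token, and the method returns how long the
// caller should wait before sending. Hypothetical, not the real RateLimiter.
class TokenBucketSketch {
    private final double maxTokens;
    private final double tokensPerNano;
    private double tokens;
    private long lastRefillNanos;

    TokenBucketSketch(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit, long nowNanos) {
        this.maxTokens = accumulatedTokensLimit;
        this.tokensPerNano = tokensPerTimeUnit / unit.toNanos(1);
        this.tokens = accumulatedTokensLimit;
        this.lastRefillNanos = nowNanos;
    }

    /** Reserves one token; returns nanos to wait before the request may run. */
    synchronized long reserveNanos(long nowNanos) {
        // refill based on elapsed time, capped at the accumulation limit
        tokens = Math.min(maxTokens, tokens + (nowNanos - lastRefillNanos) * tokensPerNano);
        lastRefillNanos = nowNanos;
        tokens -= 1;
        // a negative balance means the caller must wait for the deficit to refill
        return tokens >= 0 ? 0L : (long) (-tokens / tokensPerNano);
    }
}
```

With a cap of 1 token refilling at 1 token per second, the first reservation is immediate and the second must wait roughly a full second, which matches the "sleep the least amount" behavior described in the PR.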
    private static final Logger logger = LogManager.getLogger(RequestExecutorService.class);
    private final String serviceName;
    private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
    private final AtomicBoolean running = new AtomicBoolean(true);
    private final CountDownLatch terminationLatch = new CountDownLatch(1);
    private final HttpClientContext httpContext;
    private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);

    private final ConcurrentMap<Object, RateLimitingEndpointHandler> rateLimitGroupings = new ConcurrentHashMap<>();
    private final ThreadPool threadPool;
    private final CountDownLatch startupLatch;
    private final BlockingQueue<Runnable> controlQueue = new LinkedBlockingQueue<>();
    private final SingleRequestManager requestManager;
    private final CountDownLatch terminationLatch = new CountDownLatch(1);
    private final RequestSender requestSender;
    private final RequestExecutorServiceSettings settings;
    private final Clock clock;
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
    private final Sleeper sleeper;
    private final RateLimiterCreator rateLimiterCreator;
    private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
    private final AtomicBoolean started = new AtomicBoolean(false);

    RequestExecutorService(
        String serviceName,
        ThreadPool threadPool,
        @Nullable CountDownLatch startupLatch,
        RequestExecutorServiceSettings settings,
        SingleRequestManager requestManager
        RequestSender requestSender
    ) {
        this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager);
    }

    /**
     * This constructor should only be used directly for testing.
     */
    RequestExecutorService(
        String serviceName,
        ThreadPool threadPool,
        AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
        @Nullable CountDownLatch startupLatch,
        RequestExecutorServiceSettings settings,
        SingleRequestManager requestManager
    ) {
        this.serviceName = Objects.requireNonNull(serviceName);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.httpContext = HttpClientContext.create();
        this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
        this.startupLatch = startupLatch;
        this.requestManager = Objects.requireNonNull(requestManager);

        Objects.requireNonNull(settings);
        settings.registerQueueCapacityCallback(this::onCapacityChange);
    }

    private void onCapacityChange(int capacity) {
        logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity));

        var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity));
        if (enqueuedCapacityCommand == false) {
            logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later.");
        } else {
            // ensure that the task execution loop wakes up
            queue.offer(new NoopTask());
        }
    }

    private void updateCapacity(int newCapacity) {
        try {
            queue.setCapacity(newCapacity);
        } catch (Exception e) {
            logger.warn(
                format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName),
                e
        this(
            threadPool,
            DEFAULT_QUEUE_CREATOR,
            startupLatch,
            settings,
            requestSender,
            Clock.systemUTC(),
            DEFAULT_SLEEPER,
            DEFAULT_RATE_LIMIT_CREATOR
        );
    }

    RequestExecutorService(
        ThreadPool threadPool,
        AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator,
        @Nullable CountDownLatch startupLatch,
        RequestExecutorServiceSettings settings,
        RequestSender requestSender,
        Clock clock,
        Sleeper sleeper,
        RateLimiterCreator rateLimiterCreator
    ) {
        this.threadPool = Objects.requireNonNull(threadPool);
        this.queueCreator = Objects.requireNonNull(queueCreator);
        this.startupLatch = startupLatch;
        this.requestSender = Objects.requireNonNull(requestSender);
        this.settings = Objects.requireNonNull(settings);
        this.clock = Objects.requireNonNull(clock);
        this.sleeper = Objects.requireNonNull(sleeper);
        this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
    }

    public void shutdown() {
        if (shutdown.compareAndSet(false, true)) {
            if (cancellableCleanupTask.get() != null) {
                logger.debug(() -> "Stopping clean up thread");
                cancellableCleanupTask.get().cancel();
            }
        }
    }

    public boolean isShutdown() {
        return shutdown.get();
    }

    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return terminationLatch.await(timeout, unit);
    }

    public boolean isTerminated() {
        return terminationLatch.getCount() == 0;
    }

    public int queueSize() {
        return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum();
    }

    /**
     * Begin servicing tasks.
     * <p>
     * <b>Note: This should only be called once for the life of the object.</b>
     * </p>
     */
    public void start() {
        try {
            assert started.get() == false : "start() can only be called once";
            started.set(true);

            startCleanupTask();
            signalStartInitiated();

            while (running.get()) {
            while (isShutdown() == false) {
                handleTasks();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            running.set(false);
            shutdown();
            notifyRequestsOfShutdown();
            terminationLatch.countDown();
        }
@@ -156,108 +206,68 @@ class RequestExecutorService implements RequestExecutor {
        }
    }

    /**
     * Protects the task retrieval logic from an unexpected exception.
     *
     * @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to
     * shut down
     */
    private void startCleanupTask() {
        assert cancellableCleanupTask.get() == null : "The clean up task can only be set once";
        cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL));
    }

    private Scheduler.Cancellable startCleanupThread(TimeValue interval) {
        logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval));

        return threadPool.scheduleWithFixedDelay(this::removeStaleGroupings, interval, threadPool.executor(UTILITY_THREAD_POOL_NAME));
    }

    // default for testing
    void removeStaleGroupings() {
        var now = Instant.now(clock);
        for (var iter = rateLimitGroupings.values().iterator(); iter.hasNext();) {
            var endpoint = iter.next();

            // if the current time is after the last time the endpoint enqueued a request + allowed stale period then we'll remove it
            if (now.isAfter(endpoint.timeOfLastEnqueue().plus(settings.getRateLimitGroupStaleDuration()))) {
                endpoint.close();
                iter.remove();
            }
        }
    }
|
||||
private void handleTasks() throws InterruptedException {
|
||||
try {
|
||||
RejectableTask task = queue.take();
|
||||
|
||||
var command = controlQueue.poll();
|
||||
if (command != null) {
|
||||
command.run();
|
||||
var timeToWait = settings.getTaskPollFrequency();
|
||||
for (var endpoint : rateLimitGroupings.values()) {
|
||||
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
|
||||
}
|
||||
|
||||
// TODO add logic to complete pending items in the queue before shutting down
|
||||
if (running.get() == false) {
|
||||
logger.debug(() -> format("Http executor service [%s] exiting", serviceName));
|
||||
rejectTaskBecauseOfShutdown(task);
|
||||
} else {
|
||||
executeTask(task);
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e);
|
||||
}
|
||||
sleeper.sleep(timeToWait);
|
||||
}
|
||||
|
||||
private void executeTask(RejectableTask task) {
|
||||
try {
|
||||
requestManager.execute(task, httpContext);
|
||||
} catch (Exception e) {
|
||||
logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e);
|
||||
}
|
||||
}
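The mixed old/new lines in this hunk can be hard to follow: the refactor replaces the blocking `queue.take()` loop with a polling pass that asks every rate limit group to run one queued task and then sleeps for the smallest wait any group suggested, capped by the task poll frequency. A minimal sketch of that min-wait scheduling idea, using illustrative names rather than the actual Elasticsearch types:

```java
import java.time.Duration;
import java.util.List;

// Illustrative sketch (not the real Elasticsearch classes) of the refactored dispatch pass:
// every group gets one chance to execute a task, and the loop sleeps for the minimum of
// the groups' suggested waits and the configured poll frequency, so it neither busy-spins
// nor oversleeps while a group has a token about to become available.
final class DispatchLoopSketch {
    interface Group {
        // Duration.ZERO means a task was executed; otherwise it is the time until a token frees up.
        Duration executeEnqueuedTask();
    }

    static Duration nextSleep(List<Group> groups, Duration taskPollFrequency) {
        Duration timeToWait = taskPollFrequency;
        for (Group group : groups) {
            Duration suggested = group.executeEnqueuedTask();
            if (suggested.compareTo(timeToWait) < 0) {
                timeToWait = suggested;
            }
        }
        return timeToWait;
    }
}
```

With no groups the loop simply sleeps for the poll frequency; as soon as any group executed a task (returning zero), the next pass happens immediately to look for more work.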
    private synchronized void notifyRequestsOfShutdown() {
    private void notifyRequestsOfShutdown() {
        assert isShutdown() : "Requests should only be notified if the executor is shutting down";

        try {
            List<RejectableTask> notExecuted = new ArrayList<>();
            queue.drainTo(notExecuted);

            rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown);
        } catch (Exception e) {
            logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName));
        for (var endpoint : rateLimitGroupings.values()) {
            endpoint.notifyRequestsOfShutdown();
        }
    }

    private void rejectTaskBecauseOfShutdown(RejectableTask task) {
        try {
            task.onRejection(
                new EsRejectedExecutionException(
                    format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName),
                    true
                )
            );
        } catch (Exception e) {
            logger.warn(
                format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName)
            );
        }
    // default for testing
    Integer remainingQueueCapacity(RequestManager requestManager) {
        var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping());

        if (endpoint == null) {
            return null;
        }

    private void rejectTasks(List<RejectableTask> tasks, Consumer<RejectableTask> rejectionFunction) {
        for (var task : tasks) {
            rejectionFunction.accept(task);
        }
        return endpoint.remainingCapacity();
    }

    public int queueSize() {
        return queue.size();
    }

    @Override
    public void shutdown() {
        if (running.compareAndSet(true, false)) {
            // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns
            queue.offer(new NoopTask());
        }
    }

    @Override
    public boolean isShutdown() {
        return running.get() == false;
    }

    @Override
    public boolean isTerminated() {
        return terminationLatch.getCount() == 0;
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return terminationLatch.await(timeout, unit);
    // default for testing
    int numberOfRateLimitGroups() {
        return rateLimitGroupings.size();
    }

    /**
     * Execute the request at some point in the future.
     *
     * @param requestCreator the http request to send
     * @param requestManager the http request to send
     * @param inferenceInputs the inputs to send in the request
     * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the
     * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}.
@@ -265,13 +275,13 @@ class RequestExecutorService implements RequestExecutor {
     * @param listener an {@link ActionListener<InferenceServiceResults>} for the response or failure
     */
    public void execute(
        RequestManager requestCreator,
        RequestManager requestManager,
        InferenceInputs inferenceInputs,
        @Nullable TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        var task = new RequestTask(
            requestCreator,
            requestManager,
            inferenceInputs,
            timeout,
            threadPool,
@@ -280,13 +290,157 @@ class RequestExecutorService implements RequestExecutor {
            ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext())
        );

        completeExecution(task);
        var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> {
            var endpointHandler = new RateLimitingEndpointHandler(
                Integer.toString(requestManager.rateLimitGrouping().hashCode()),
                queueCreator,
                settings,
                requestSender,
                clock,
                requestManager.rateLimitSettings(),
                this::isShutdown,
                rateLimiterCreator
            );

            endpointHandler.init();
            return endpointHandler;
        });

        endpoint.enqueue(task);
    }

    private void completeExecution(RequestTask task) {
    /**
     * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests.
     * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent
     * as soon as a thread is available.
     */
    private static class RateLimitingEndpointHandler {

        private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE;
        private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO;
        private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class);
        private static final int ACCUMULATED_TOKENS_LIMIT = 1;

        private final AdjustableCapacityBlockingQueue<RejectableTask> queue;
        private final Supplier<Boolean> isShutdownMethod;
        private final RequestSender requestSender;
        private final String id;
        private final AtomicReference<Instant> timeOfLastEnqueue = new AtomicReference<>();
        private final Clock clock;
        private final RateLimiter rateLimiter;
        private final RequestExecutorServiceSettings requestExecutorServiceSettings;

        RateLimitingEndpointHandler(
            String id,
            AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> createQueue,
            RequestExecutorServiceSettings settings,
            RequestSender requestSender,
            Clock clock,
            RateLimitSettings rateLimitSettings,
            Supplier<Boolean> isShutdownMethod,
            RateLimiterCreator rateLimiterCreator
        ) {
            this.requestExecutorServiceSettings = Objects.requireNonNull(settings);
            this.id = Objects.requireNonNull(id);
            this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity());
            this.requestSender = Objects.requireNonNull(requestSender);
            this.clock = Objects.requireNonNull(clock);
            this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod);

            Objects.requireNonNull(rateLimitSettings);
            Objects.requireNonNull(rateLimiterCreator);
            rateLimiter = rateLimiterCreator.create(
                ACCUMULATED_TOKENS_LIMIT,
                rateLimitSettings.requestsPerTimeUnit(),
                rateLimitSettings.timeUnit()
            );

        }

        public void init() {
            requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange);
        }

        private void onCapacityChange(int capacity) {
            logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity));

            try {
                queue.setCapacity(capacity);
            } catch (Exception e) {
                logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e);
            }
        }

        public int queueSize() {
            return queue.size();
        }

        public boolean isShutdown() {
            return isShutdownMethod.get();
        }

        public Instant timeOfLastEnqueue() {
            return timeOfLastEnqueue.get();
        }

        public synchronized TimeValue executeEnqueuedTask() {
            try {
                return executeEnqueuedTaskInternal();
            } catch (Exception e) {
                logger.warn(format("Executor service grouping [%s] failed to execute request", id), e);
                // we tried to do some work but failed, so we'll say we did something to try looking for more work
                return EXECUTED_A_TASK;
            }
        }

        private TimeValue executeEnqueuedTaskInternal() {
            var timeBeforeAvailableToken = rateLimiter.timeToReserve(1);
            if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) {
                return timeBeforeAvailableToken;
            }

            var task = queue.poll();

            // TODO Batching - in a situation where no new tasks are queued we'll want to execute any prepared tasks
            // So we'll need to check for null and call a helper method executePreparedTasks()

            if (shouldExecuteTask(task) == false) {
                return NO_TASKS_AVAILABLE;
            }

            // We should never have to wait because we checked above
            var reserveRes = rateLimiter.reserve(1);
            assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have";

            task.getRequestManager()
                .execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener());
            return EXECUTED_A_TASK;
        }

        private static boolean shouldExecuteTask(RejectableTask task) {
            return task != null && isNoopRequest(task) == false && task.hasCompleted() == false;
        }

        private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
            return inferenceRequest.getRequestManager() == null
                || inferenceRequest.getInput() == null
                || inferenceRequest.getListener() == null;
        }

        private static boolean shouldExecuteImmediately(TimeValue delay) {
            return delay.duration() == 0;
        }

        public void enqueue(RequestTask task) {
            timeOfLastEnqueue.set(Instant.now(clock));

            if (isShutdown()) {
                EsRejectedExecutionException rejected = new EsRejectedExecutionException(
                    format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName),
                    format(
                        "Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown",
                        task.getRequestManager().inferenceEntityId(),
                        id
                    ),
                    true
                );
@@ -294,24 +448,72 @@
                return;
            }

            boolean added = queue.offer(task);
            if (added == false) {
            var addedToQueue = queue.offer(task);

            if (addedToQueue == false) {
                EsRejectedExecutionException rejected = new EsRejectedExecutionException(
                    format("Failed to execute task because the http executor service [%s] queue is full", serviceName),
                    format(
                        "Failed to execute task for inference id [%s] because the request service [%s] queue is full",
                        task.getRequestManager().inferenceEntityId(),
                        id
                    ),
                    false
                );

                task.onRejection(rejected);
            } else if (isShutdown()) {
                // It is possible that a shutdown and notification request occurred after we initially checked for shutdown above
                // If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if
                // we shut down and if so we'll redo the notification
                notifyRequestsOfShutdown();
            }
        }

        // default for testing
        int remainingQueueCapacity() {
        public synchronized void notifyRequestsOfShutdown() {
            assert isShutdown() : "Requests should only be notified if the executor is shutting down";

            try {
                List<RejectableTask> notExecuted = new ArrayList<>();
                queue.drainTo(notExecuted);

                rejectTasks(notExecuted);
            } catch (Exception e) {
                logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id));
            }
        }

        private void rejectTasks(List<RejectableTask> tasks) {
            for (var task : tasks) {
                rejectTaskForShutdown(task);
            }
        }

        private void rejectTaskForShutdown(RejectableTask task) {
            try {
                task.onRejection(
                    new EsRejectedExecutionException(
                        format(
                            "Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request",
                            id,
                            task.getRequestManager().inferenceEntityId()
                        ),
                        true
                    )
                );
            } catch (Exception e) {
                logger.warn(
                    format(
                        "Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown",
                        task.getRequestManager().inferenceEntityId(),
                        id
                    )
                );
            }
        }

        public int remainingCapacity() {
            return queue.remainingCapacity();
        }

        public void close() {
            requestExecutorServiceSettings.deregisterQueueCapacityCallback(id);
        }
    }
}
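The new `executeEnqueuedTaskInternal()` leans on two rate limiter calls: `timeToReserve(n)` reports how long until `n` tokens would be available without consuming anything, and `reserve(n)` actually takes them. A hedged sketch of the token-bucket arithmetic behind that pair; this is illustrative only, the real `RateLimiter` lives elsewhere in the code base and its API may differ:

```java
import java.time.Duration;

// Hypothetical token-bucket sketch of the timeToReserve()/reserve() behavior.
// The accumulated-token cap plays the role of ACCUMULATED_TOKENS_LIMIT: capping it at 1
// prevents a long-idle group from bursting many requests at once when traffic resumes.
final class TokenBucketSketch {
    private final double tokensPerMicro;  // refill rate
    private final double maxAccumulated;  // cap on banked tokens
    private double tokens;
    private long lastRefillMicros;

    TokenBucketSketch(double maxAccumulated, double requestsPerSecond, long nowMicros) {
        this.maxAccumulated = maxAccumulated;
        this.tokensPerMicro = requestsPerSecond / 1_000_000.0;
        this.tokens = maxAccumulated;
        this.lastRefillMicros = nowMicros;
    }

    private void refill(long nowMicros) {
        tokens = Math.min(maxAccumulated, tokens + (nowMicros - lastRefillMicros) * tokensPerMicro);
        lastRefillMicros = nowMicros;
    }

    /** Time to wait before n tokens are available; zero means execute immediately. */
    Duration timeToReserve(int n, long nowMicros) {
        refill(nowMicros);
        double deficit = n - tokens;
        return deficit <= 0 ? Duration.ZERO : Duration.ofNanos((long) (deficit / tokensPerMicro * 1_000));
    }

    /** Consume n tokens; callers are expected to have checked timeToReserve first. */
    void reserve(int n, long nowMicros) {
        refill(nowMicros);
        tokens -= n;
    }
}
```

For example, a bucket capped at 1 token refilling at 2 requests per second can execute immediately once, after which `timeToReserve(1)` reports a 500 ms wait until the next token accrues.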
@@ -10,9 +10,12 @@ package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;

import java.util.ArrayList;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;

public class RequestExecutorServiceSettings {

@@ -29,37 +32,108 @@
        Setting.Property.Dynamic
    );

    private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50);
    /**
     * Defines how often all the rate limit groups are polled for tasks. Setting this to a very low number could result
     * in a busy loop if there are no tasks available to handle.
     */
    static final Setting<TimeValue> TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting(
        "xpack.inference.http.request_executor.task_poll_frequency",
        DEFAULT_TASK_POLL_FREQUENCY_TIME,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
    /**
     * Defines how often a thread will check for rate limit groups that are stale.
     */
    static final Setting<TimeValue> RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING = Setting.timeSetting(
        "xpack.inference.http.request_executor.rate_limit_group_cleanup_interval",
        DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION = TimeValue.timeValueDays(10);
    /**
     * Defines the amount of time it takes to classify a rate limit group as stale. Once it is classified as stale,
     * it can be removed when the cleanup thread executes.
     */
    static final Setting<TimeValue> RATE_LIMIT_GROUP_STALE_DURATION_SETTING = Setting.timeSetting(
        "xpack.inference.http.request_executor.rate_limit_group_stale_duration",
        DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    public static List<Setting<?>> getSettingsDefinitions() {
        return List.of(TASK_QUEUE_CAPACITY_SETTING);
        return List.of(
            TASK_QUEUE_CAPACITY_SETTING,
            TASK_POLL_FREQUENCY_SETTING,
            RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING,
            RATE_LIMIT_GROUP_STALE_DURATION_SETTING
        );
    }

    private volatile int queueCapacity;
    private final List<Consumer<Integer>> queueCapacityCallbacks = new ArrayList<Consumer<Integer>>();
    private volatile TimeValue taskPollFrequency;
    private volatile Duration rateLimitGroupStaleDuration;
    private final ConcurrentMap<String, Consumer<Integer>> queueCapacityCallbacks = new ConcurrentHashMap<>();

    public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) {
        queueCapacity = TASK_QUEUE_CAPACITY_SETTING.get(settings);
        taskPollFrequency = TASK_POLL_FREQUENCY_SETTING.get(settings);
        setRateLimitGroupStaleDuration(RATE_LIMIT_GROUP_STALE_DURATION_SETTING.get(settings));

        addSettingsUpdateConsumers(clusterService);
    }

    private void addSettingsUpdateConsumers(ClusterService clusterService) {
        clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(RATE_LIMIT_GROUP_STALE_DURATION_SETTING, this::setRateLimitGroupStaleDuration);
    }

    // default for testing
    void setQueueCapacity(int queueCapacity) {
        this.queueCapacity = queueCapacity;

        for (var callback : queueCapacityCallbacks) {
        for (var callback : queueCapacityCallbacks.values()) {
            callback.accept(queueCapacity);
        }
    }

    void registerQueueCapacityCallback(Consumer<Integer> onChangeCapacityCallback) {
        queueCapacityCallbacks.add(onChangeCapacityCallback);
    private void setTaskPollFrequency(TimeValue taskPollFrequency) {
        this.taskPollFrequency = taskPollFrequency;
    }

    private void setRateLimitGroupStaleDuration(TimeValue staleDuration) {
        rateLimitGroupStaleDuration = toDuration(staleDuration);
    }

    private static Duration toDuration(TimeValue timeValue) {
        return Duration.of(timeValue.duration(), timeValue.timeUnit().toChronoUnit());
    }

    void registerQueueCapacityCallback(String id, Consumer<Integer> onChangeCapacityCallback) {
        queueCapacityCallbacks.put(id, onChangeCapacityCallback);
    }

    void deregisterQueueCapacityCallback(String id) {
        queueCapacityCallbacks.remove(id);
    }

    int getQueueCapacity() {
        return queueCapacity;
    }

    TimeValue getTaskPollFrequency() {
        return taskPollFrequency;
    }

    Duration getRateLimitGroupStaleDuration() {
        return rateLimitGroupStaleDuration;
    }
}
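This settings class switches the queue-capacity callbacks from a `List` to a `ConcurrentMap` keyed by grouping id, so a rate limit group can deregister its own callback when the cleanup thread removes it. The pattern in isolation, as a simplified sketch rather than the actual class:

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;

// Simplified sketch of the keyed callback registry: with a plain List there is no
// reliable handle for removing one callback later, while a Map keyed by id makes
// deregistration trivial and stays safe under concurrent registration.
final class CallbackRegistrySketch {
    private final ConcurrentMap<String, Consumer<Integer>> callbacks = new ConcurrentHashMap<>();

    void register(String id, Consumer<Integer> callback) {
        callbacks.put(id, callback);
    }

    void deregister(String id) {
        callbacks.remove(id);
    }

    // Fan a new capacity value out to every still-registered callback.
    void publish(int newCapacity) {
        for (Consumer<Integer> callback : callbacks.values()) {
            callback.accept(newCapacity);
        }
    }
}
```

After `deregister("some-id")`, later `publish` calls no longer reach that callback, which is exactly what `RateLimitingEndpointHandler.close()` relies on.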
@@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;

@@ -21,14 +20,17 @@ import java.util.function.Supplier;
 * A contract for constructing a {@link Runnable} to handle sending an inference request to a 3rd party service.
 */
public interface RequestManager extends RateLimitable {
    Runnable create(
    void execute(
        @Nullable String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        HttpClientContext context,
        ActionListener<InferenceServiceResults> listener
    );

    // TODO For batching we'll add 2 new methods: prepare(query, input, ...) which will allow the individual
    // managers to implement their own batching
    // executePreparedRequest() which will execute all prepared requests aka sends the batch

    String inferenceEntityId();
}
@@ -111,7 +111,7 @@ class RequestTask implements RejectableTask {
    }

    @Override
    public RequestManager getRequestCreator() {
    public RequestManager getRequestManager() {
        return requestCreator;
    }
}
@@ -1,48 +0,0 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.http.client.protocol.HttpClientContext;
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;

import java.util.Objects;

/**
 * Handles executing a single inference request at a time.
 */
public class SingleRequestManager {

    protected RetryingHttpSender requestSender;

    public SingleRequestManager(RetryingHttpSender requestSender) {
        this.requestSender = Objects.requireNonNull(requestSender);
    }

    public void execute(InferenceRequest inferenceRequest, HttpClientContext context) {
        if (isNoopRequest(inferenceRequest) || inferenceRequest.hasCompleted()) {
            return;
        }

        inferenceRequest.getRequestCreator()
            .create(
                inferenceRequest.getQuery(),
                inferenceRequest.getInput(),
                requestSender,
                inferenceRequest.getRequestCompletedFunction(),
                context,
                inferenceRequest.getListener()
            )
            .run();
    }

    private static boolean isNoopRequest(InferenceRequest inferenceRequest) {
        return inferenceRequest.getRequestCreator() == null
            || inferenceRequest.getInput() == null
            || inferenceRequest.getListener() == null;
    }
}
@@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService {

    public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        Objects.requireNonNull(factory);
        sender = factory.createSender(name());
        sender = factory.createSender();
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }
@@ -56,7 +56,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completio

public class AzureAiStudioService extends SenderService {

    private static final String NAME = "azureaistudio";
    static final String NAME = "azureaistudio";

    public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
@@ -44,7 +44,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec
        ConfigurationParseContext context
    ) {
        String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
            map,
            DEFAULT_RATE_LIMIT_SETTINGS,
            validationException,
            AzureAiStudioService.NAME,
            context
        );
        AzureAiStudioEndpointType endpointType = extractRequiredEnum(
            map,
            ENDPOINT_TYPE_FIELD,

@@ -118,13 +124,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec

    protected void addXContentFields(XContentBuilder builder, Params params) throws IOException {
        this.addExposedXContentFields(builder, params);
        rateLimitSettings.toXContent(builder, params);
    }

    protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException {
        builder.field(TARGET_FIELD, this.target);
        builder.field(PROVIDER_FIELD, this.provider);
        builder.field(ENDPOINT_TYPE_FIELD, this.endpointType);
        rateLimitSettings.toXContent(builder, params);
    }

}
@@ -135,7 +135,15 @@ public class AzureOpenAiService extends SenderService {
                );
            }
            case COMPLETION -> {
                return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
                return new AzureOpenAiCompletionModel(
                    inferenceEntityId,
                    taskType,
                    NAME,
                    serviceSettings,
                    taskSettings,
                    secretSettings,
                    context
                );
            }
            default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
        }
@@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;

@@ -37,13 +38,14 @@ public class AzureOpenAiCompletionModel extends AzureOpenAiModel {
        String service,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        @Nullable Map<String, Object> secrets
        @Nullable Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings),
            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context),
            AzureOpenAiCompletionTaskSettings.fromMap(taskSettings),
            AzureOpenAiSecretSettings.fromMap(secrets)
        );
@@ -17,7 +17,9 @@ import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -55,10 +57,10 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
     */
    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);

    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map) {
    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
        ValidationException validationException = new ValidationException();

        var settings = fromMap(map, validationException);
        var settings = fromMap(map, validationException, context);

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;

@@ -69,12 +71,19 @@

    private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap(
        Map<String, Object> map,
        ValidationException validationException
        ValidationException validationException,
        ConfigurationParseContext context
    ) {
        String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException);
        String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
        String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
            map,
            DEFAULT_RATE_LIMIT_SETTINGS,
            validationException,
            AzureOpenAiService.NAME,
            context
        );

        return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings);
    }

@@ -137,7 +146,6 @@
        builder.startObject();

        toXContentFragmentOfExposedFields(builder, params);
        rateLimitSettings.toXContent(builder, params);

        builder.endObject();
        return builder;

@@ -148,6 +156,7 @@
        builder.field(RESOURCE_NAME, resourceName);
        builder.field(DEPLOYMENT_ID, deploymentId);
        builder.field(API_VERSION, apiVersion);
        rateLimitSettings.toXContent(builder, params);

        return builder;
    }
|
|
@@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -90,7 +91,13 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
         SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AzureOpenAiService.NAME,
+            context
+        );

         Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);

@@ -245,8 +252,6 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-
-        rateLimitSettings.toXContent(builder, params);
         builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);

         builder.endObject();
@@ -268,6 +273,7 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
         if (similarity != null) {
             builder.field(SIMILARITY, similarity);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -51,6 +51,11 @@ import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFie
 public class CohereService extends SenderService {
     public static final String NAME = "cohere";

+    // TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to
+    // the Cohere*RequestManager via the CohereActionCreator class
+    // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated
+    // on every request
+
     public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
     }
@@ -131,7 +136,15 @@ public class CohereService extends SenderService {
                 context
             );
             case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
-            case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+            case COMPLETION -> new CohereCompletionModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }
@@ -58,7 +58,13 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            CohereService.NAME,
+            context
+        );

         String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

@@ -173,10 +179,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
     }

     public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
-        toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
-
-        return builder;
+        return toXContentFragmentOfExposedFields(builder, params);
     }

     @Override
@@ -196,6 +199,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
         if (modelId != null) {
             builder.field(MODEL_ID, modelId);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -30,13 +31,14 @@ public class CohereCompletionModel extends CohereModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             modelId,
             taskType,
             service,
-            CohereCompletionServiceSettings.fromMap(serviceSettings),
+            CohereCompletionServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );
@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.cohere.CohereService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -39,12 +41,18 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
     // 10K requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);

-    public static CohereCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static CohereCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            CohereService.NAME,
+            context
+        );
         String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

         if (validationException.validationErrors().isEmpty() == false) {
@@ -94,7 +102,6 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -127,6 +134,7 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
         if (modelId != null) {
             builder.field(MODEL_ID, modelId);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -108,7 +108,8 @@ public class GoogleAiStudioService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel(
                 inferenceEntityId,
@@ -116,7 +117,8 @@ public class GoogleAiStudioService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
@@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -37,13 +38,14 @@ public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        Map<String, Object> secrets
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings),
+            GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );
@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -40,11 +42,17 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);

-    public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            GoogleAiStudioService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -82,7 +90,6 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -107,6 +114,7 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj
     @Override
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(MODEL_ID, modelId);
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -37,13 +38,14 @@ public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        Map<String, Object> secrets
+        Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings),
+            GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             DefaultSecretSettings.fromMap(secrets)
         );
@@ -18,7 +18,9 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -47,7 +49,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);

-    public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map) {
+    public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

         String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -59,7 +61,13 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         );
         SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
         Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            GoogleAiStudioService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -134,7 +142,6 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         builder.startObject();

         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);

         builder.endObject();
         return builder;
@@ -174,6 +181,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj
         if (similarity != null) {
             builder.field(SIMILARITY, similarity);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;

@@ -62,7 +63,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
             taskType,
             serviceSettingsMap,
             serviceSettingsMap,
-            TaskType.unsupportedTaskTypeErrorMsg(taskType, name())
+            TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
+            ConfigurationParseContext.REQUEST
         );

         throwIfNotEmptyMap(config, name());
@@ -89,7 +91,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
             taskType,
             serviceSettingsMap,
             secretSettingsMap,
-            parsePersistedConfigErrorMsg(inferenceEntityId, name())
+            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
+            ConfigurationParseContext.PERSISTENT
         );
     }

@@ -97,7 +100,14 @@ public abstract class HuggingFaceBaseService extends SenderService {
     public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);

-        return createModel(inferenceEntityId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(inferenceEntityId, name()));
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            null,
+            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
+            ConfigurationParseContext.PERSISTENT
+        );
     }

     protected abstract HuggingFaceModel createModel(
@@ -105,7 +115,8 @@ public abstract class HuggingFaceBaseService extends SenderService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     );

     @Override
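The `HuggingFaceBaseService` changes above thread a `ConfigurationParseContext` into `createModel`: `REQUEST` on the endpoint-creation path and `PERSISTENT` when reading a config back from the index. The usual motive for such a split is strict validation of user input versus lenient reading of stored state; a hedged sketch of that idea, with simplified names that are not the actual Elasticsearch implementation:

```java
import java.util.Map;

// Sketch: strict parsing for new requests, lenient parsing for persisted configs.
class ParseContextSketch {
    enum ConfigurationParseContext { REQUEST, PERSISTENT }

    // Returns the rate limit, rejecting unknown keys only for REQUEST-time parsing.
    static long parseRequestsPerMinute(Map<String, ?> rateLimit, ConfigurationParseContext context) {
        for (String key : rateLimit.keySet()) {
            if (!key.equals("requests_per_minute") && context == ConfigurationParseContext.REQUEST) {
                throw new IllegalArgumentException("unknown rate_limit field [" + key + "]");
            }
        }
        Object rpm = rateLimit.get("requests_per_minute");
        return rpm instanceof Number n ? n.longValue() : 3000L;
    }

    public static void main(String[] args) {
        // Persisted configs tolerate fields written by other versions...
        System.out.println(parseRequestsPerMinute(
            Map.of("requests_per_minute", 10, "burst", 5), ConfigurationParseContext.PERSISTENT));
        // ...while request-time parsing rejects them outright.
        try {
            parseRequestsPerMinute(Map.of("burst", 5), ConfigurationParseContext.REQUEST);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The design point is that a persisted document must never fail to load after an upgrade, while a malformed create request should fail fast with a clear validation error.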
@@ -16,6 +16,7 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
@@ -36,11 +37,19 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return switch (taskType) {
-            case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
-            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
+            case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                secretSettings,
+                context
+            );
+            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }
@@ -18,6 +18,7 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -43,14 +44,20 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
     // 3000 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);

-    public static HuggingFaceServiceSettings fromMap(Map<String, Object> map) {
+    public static HuggingFaceServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var uri = extractUri(map, URL, validationException);

         SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            HuggingFaceService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -119,7 +126,6 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
     }
@@ -136,6 +142,7 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

@@ -24,13 +25,14 @@ public class HuggingFaceElserModel extends HuggingFaceModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            HuggingFaceElserServiceSettings.fromMap(serviceSettings),
+            HuggingFaceElserServiceSettings.fromMap(serviceSettings, context),
             DefaultSecretSettings.fromMap(secrets)
         );
     }
@@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
@@ -38,10 +39,11 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return switch (taskType) {
-            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings);
+            case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }
@@ -15,7 +15,9 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -40,10 +42,16 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     // 3000 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);

-    public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map) {
+    public static HuggingFaceElserServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var uri = extractUri(map, URL, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            HuggingFaceService.NAME,
+            context
+        );

         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -93,7 +101,6 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();

         return builder;
@@ -103,6 +110,7 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         builder.field(URL, uri.toString());
         builder.field(MAX_INPUT_TOKENS, ELSER_TOKEN_LIMIT);
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
@@ -25,13 +26,14 @@ public class HuggingFaceEmbeddingsModel extends HuggingFaceModel {
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            HuggingFaceServiceSettings.fromMap(serviceSettings),
+            HuggingFaceServiceSettings.fromMap(serviceSettings, context),
             DefaultSecretSettings.fromMap(secrets)
         );
     }
@@ -18,6 +18,7 @@ import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -59,7 +60,13 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
             ModelConfigurations.SERVICE_SETTINGS,
             validationException
         );
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            MistralService.NAME,
+            context
+        );
        Integer dims = removeAsType(map, DIMENSIONS, Integer.class);

         if (validationException.validationErrors().isEmpty() == false) {
@@ -141,7 +148,6 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         this.toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
         builder.endObject();
         return builder;
     }
@@ -159,6 +165,7 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
         if (this.maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, this.maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);

         return builder;
     }
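A recurring shape in the serialization hunks: `rateLimitSettings.toXContent(builder, params)` moves out of each `toXContent` body and into the shared `toXContentFragmentOfExposedFields`, so the rate limit is written once by the common fragment rather than separately by every caller. A simplified sketch of that refactor, using a `StringBuilder` in place of the real `XContentBuilder` (illustrative only, not the actual API):

```java
// Sketch: move a field from the outer serializer into the shared "exposed fields"
// fragment so every serialization path emits it exactly once.
class XContentSketch {
    final String modelId;
    final long requestsPerMinute;

    XContentSketch(String modelId, long requestsPerMinute) {
        this.modelId = modelId;
        this.requestsPerMinute = requestsPerMinute;
    }

    // Shared fragment: after the refactor it also owns the rate limit field.
    StringBuilder toFragmentOfExposedFields(StringBuilder b) {
        b.append("\"model_id\":\"").append(modelId).append("\",");
        b.append("\"rate_limit\":{\"requests_per_minute\":").append(requestsPerMinute).append("}");
        return b;
    }

    // Outer serializer no longer appends the rate limit itself.
    String toXContent() {
        StringBuilder b = new StringBuilder("{");
        toFragmentOfExposedFields(b);
        return b.append("}").toString();
    }

    public static void main(String[] args) {
        System.out.println(new XContentSketch("command-r", 10000).toXContent());
    }
}
```

Centralizing the field in the fragment keeps the full and the filtered (exposed-fields) serializations in agreement, which is why the diff deletes the line from `toXContent` and adds it to the fragment in every service.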
@@ -138,7 +138,8 @@ public class OpenAiService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
-                secretSettings
+                secretSettings,
+                context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
@@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 
@@ -35,13 +36,14 @@ public class OpenAiChatCompletionModel extends OpenAiModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secrets
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
     ) {
         this(
             inferenceEntityId,
             taskType,
             service,
-            OpenAiChatCompletionServiceSettings.fromMap(serviceSettings),
+            OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
             OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
             DefaultSecretSettings.fromMap(secrets)
         );
@@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
@@ -47,7 +49,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
     // 500 requests per minute
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
 
-    public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map) {
+    public static OpenAiChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
 
         String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -58,7 +60,13 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
 
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
 
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            OpenAiService.NAME,
+            context
+        );
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -142,7 +150,6 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         builder.startObject();
 
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
 
         builder.endObject();
         return builder;
@@ -163,6 +170,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);
 
         return builder;
     }
@@ -20,6 +20,7 @@ import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
@@ -66,7 +67,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         // passed at that time and never throw.
         ValidationException validationException = new ValidationException();
 
-        var commonFields = fromMap(map, validationException);
+        var commonFields = fromMap(map, validationException, ConfigurationParseContext.PERSISTENT);
 
         Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
         if (dimensionsSetByUser == null) {
@@ -80,7 +81,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
     private static OpenAiEmbeddingsServiceSettings fromRequestMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
 
-        var commonFields = fromMap(map, validationException);
+        var commonFields = fromMap(map, validationException, ConfigurationParseContext.REQUEST);
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
@@ -89,7 +90,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         return new OpenAiEmbeddingsServiceSettings(commonFields, commonFields.dimensions != null);
     }
 
-    private static CommonFields fromMap(Map<String, Object> map, ValidationException validationException) {
+    private static CommonFields fromMap(
+        Map<String, Object> map,
+        ValidationException validationException,
+        ConfigurationParseContext context
+    ) {
 
         String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException);
@@ -98,7 +103,13 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
         String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
-        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            OpenAiService.NAME,
+            context
+        );
 
         return new CommonFields(modelId, uri, organizationId, similarity, maxInputTokens, dims, rateLimitSettings);
     }
@@ -258,7 +269,6 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         builder.startObject();
 
         toXContentFragmentOfExposedFields(builder, params);
-        rateLimitSettings.toXContent(builder, params);
 
         if (dimensionsSetByUser != null) {
             builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
@@ -286,6 +296,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        rateLimitSettings.toXContent(builder, params);
 
         return builder;
     }
@@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 
 import java.io.IOException;
 import java.util.Map;
@@ -21,19 +22,29 @@ import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
 
 public class RateLimitSettings implements Writeable, ToXContentFragment {
 
     public static final String FIELD_NAME = "rate_limit";
     public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute";
 
     private final long requestsPerTimeUnit;
     private final TimeUnit timeUnit;
 
-    public static RateLimitSettings of(Map<String, Object> map, RateLimitSettings defaultValue, ValidationException validationException) {
+    public static RateLimitSettings of(
+        Map<String, Object> map,
+        RateLimitSettings defaultValue,
+        ValidationException validationException,
+        String serviceName,
+        ConfigurationParseContext context
+    ) {
         Map<String, Object> settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME);
         var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException);
 
+        if (ConfigurationParseContext.isRequestContext(context)) {
+            throwIfNotEmptyMap(settings, serviceName);
+        }
+
         return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute);
     }
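The refactored `RateLimitSettings.of` above pulls a nested `rate_limit` object out of the service-settings map, reads `requests_per_minute`, falls back to the service's default when the field is absent, and (only when parsing a user request, not persisted config) rejects unrecognized keys inside `rate_limit`. The sketch below is a hypothetical, stripped-down re-implementation of that parsing logic for illustration only; the class name `RateLimitSettingsSketch` and the boolean `isRequestContext` parameter stand in for the real `RateLimitSettings` and `ConfigurationParseContext` types and are not part of the Elasticsearch codebase.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the rate-limit settings parsing shown in the diff above.
public class RateLimitSettingsSketch {
    static final String FIELD_NAME = "rate_limit";
    static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute";

    final long requestsPerMinute;

    RateLimitSettingsSketch(long requestsPerMinute) {
        this.requestsPerMinute = requestsPerMinute;
    }

    // Mirrors the shape of RateLimitSettings.of: remove the nested "rate_limit"
    // object from the settings map, read "requests_per_minute", and fall back to
    // the per-service default when it is absent. In a request context, leftover
    // keys inside "rate_limit" are treated as a validation error.
    @SuppressWarnings("unchecked")
    static RateLimitSettingsSketch of(Map<String, Object> map, RateLimitSettingsSketch defaultValue, boolean isRequestContext) {
        Map<String, Object> settings = (Map<String, Object>) map.getOrDefault(FIELD_NAME, new HashMap<>());
        map.remove(FIELD_NAME);
        Object requestsPerMinute = settings.remove(REQUESTS_PER_MINUTE_FIELD);

        if (isRequestContext && settings.isEmpty() == false) {
            throw new IllegalArgumentException("unknown rate_limit fields: " + settings.keySet());
        }

        return requestsPerMinute == null ? defaultValue : new RateLimitSettingsSketch(((Number) requestsPerMinute).longValue());
    }

    public static void main(String[] args) {
        Map<String, Object> serviceSettings = new HashMap<>();
        serviceSettings.put(FIELD_NAME, new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, 120L)));

        var parsed = of(serviceSettings, new RateLimitSettingsSketch(500), true);
        System.out.println(parsed.requestsPerMinute); // 120

        var fallback = of(new HashMap<>(), new RateLimitSettingsSketch(500), true);
        System.out.println(fallback.requestsPerMinute); // 500
    }
}
```

The strict-rejection branch runs only for request-time parsing; persisted settings written by older nodes are read leniently, which matches the comment in the `OpenAiEmbeddingsServiceSettings` hunk about never throwing for stored configurations.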
@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -92,7 +93,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             TruncatorTests.createTruncator()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
@@ -141,7 +142,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             TruncatorTests.createTruncator()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
            sender.start();
 
            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));
@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
@@ -82,7 +83,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -132,7 +133,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -183,7 +184,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
         // timeout as zero for no retries
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -237,7 +238,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // note - there is no complete documentation on Azure's error messages
@@ -313,7 +314,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // note - there is no complete documentation on Azure's error messages
@@ -389,7 +390,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testExecute_TruncatesInputBeforeSending() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -440,7 +441,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -498,7 +499,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
     public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -554,7 +555,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
         // timeout as zero for no retries
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // "choices" missing
@@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
 import static org.hamcrest.Matchers.hasSize;
@@ -77,7 +78,7 @@ public class AzureOpenAiCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
@@ -81,7 +82,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
             mockClusterServiceEmpty()
         );
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -73,7 +74,7 @@ public class CohereActionCreatorTests extends ESTestCase {
     public void testCreate_CohereEmbeddingsModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -154,7 +155,7 @@ public class CohereActionCreatorTests extends ESTestCase {
     public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -214,7 +215,7 @@ public class CohereActionCreatorTests extends ESTestCase {
     public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -77,7 +77,7 @@ public class CohereCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -138,7 +138,7 @@ public class CohereCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -290,7 +290,7 @@ public class CohereCompletionActionTests extends ESTestCase {
     public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -81,7 +81,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -162,7 +162,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -74,7 +74,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -206,7 +206,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase {
     public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -79,7 +79,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
         var input = "input";
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             sender.start();
 
             String responseJson = """
@@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.Matchers.contains;
@@ -75,7 +76,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -131,7 +132,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -187,7 +188,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -239,7 +240,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]]
@@ -292,7 +293,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJsonContentTooLarge = """
@@ -357,7 +358,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
     public void testExecute_TruncatesInputBeforeSending() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -74,7 +75,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiEmbeddingsModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -127,7 +128,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -179,7 +180,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -238,7 +239,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -292,7 +293,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiChatCompletionModel() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -355,7 +356,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -417,7 +418,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
     public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -486,7 +487,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
         );
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
 
-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             String responseJson = """
@@ -552,7 +553,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
|||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = senderFactory.createSender("test_service")) {
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
var contentTooLargeErrorMessage =
|
||||
|
@ -635,7 +636,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
|||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = senderFactory.createSender("test_service")) {
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
var contentTooLargeErrorMessage =
|
||||
|
@ -718,7 +719,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
|
|||
public void testExecute_TruncatesInputBeforeSending() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = senderFactory.createSender("test_service")) {
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
|
|
|
@@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -80,7 +81,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
     public void testExecute_ReturnsSuccessfulResponse() throws IOException {
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();

             String responseJson = """
@@ -234,7 +235,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase {
     public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();

             String responseJson = """
@@ -79,7 +79,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
             mockClusterServiceEmpty()
         );

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             sender.start();

             String responseJson = """
@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.mockito.Mockito.mock;
+
+public class BaseRequestManagerTests extends ESTestCase {
+    public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() {
+        int val1 = 1;
+        int val2 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), is(manager2.rateLimitGrouping()));
+    }
+
+    public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() {
+        int val1 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
+    }
+
+    public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() {
+        int val1 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
+    }
+}
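The new `BaseRequestManagerTests` above asserts that two managers built from an equal grouping value and equal `RateLimitSettings` share a rate limit group, while differing settings (or time units) split them apart. A grouping key of that kind can be modeled as a plain hash over both values; this is an illustrative sketch only — the class and method names below are hypothetical, not the Elasticsearch API.

```java
import java.util.Objects;

// Sketch: request managers with equal endpoint identity and equal rate limit
// settings produce the same grouping key, so they share one rate limiter.
final class RateLimitGroupingSketch {
    static int groupingKey(Object endpointIdentity, long requestsPerTimeUnit, String timeUnit) {
        // Objects.hash combines all three values; any difference in settings
        // (rate or time unit) yields a different group.
        return Objects.hash(endpointIdentity, requestsPerTimeUnit, timeUnit);
    }
}
```

This mirrors why the tests compare `rateLimitGrouping()` values rather than object references: grouping is by value equality, not identity.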
@@ -79,7 +79,7 @@ public class HttpRequestSenderTests extends ESTestCase {
     public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
         var senderFactory = createSenderFactory(clientManager, threadRef);

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = createSender(senderFactory)) {
             sender.start();

             String responseJson = """
@@ -135,11 +135,11 @@ public class HttpRequestSenderTests extends ESTestCase {
             mockClusterServiceEmpty()
         );

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var thrownException = expectThrows(
                 AssertionError.class,
-                () -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
+                () -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener)
             );
             assertThat(thrownException.getMessage(), is("call start() before sending a request"));
         }
@@ -155,17 +155,12 @@ public class HttpRequestSenderTests extends ESTestCase {
             mockClusterServiceEmpty()
         );

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             assertThat(sender, instanceOf(HttpRequestSender.class));
             sender.start();

             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            sender.send(
-                ExecutableRequestCreatorTests.createMock(),
-                new DocumentsOnlyInput(List.of()),
-                TimeValue.timeValueNanos(1),
-                listener
-            );
+            sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);

             var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));

@@ -186,16 +181,11 @@ public class HttpRequestSenderTests extends ESTestCase {
             mockClusterServiceEmpty()
         );

-        try (var sender = senderFactory.createSender("test_service")) {
+        try (var sender = senderFactory.createSender()) {
             sender.start();

             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            sender.send(
-                ExecutableRequestCreatorTests.createMock(),
-                new DocumentsOnlyInput(List.of()),
-                TimeValue.timeValueNanos(1),
-                listener
-            );
+            sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);

             var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));

@@ -220,6 +210,7 @@ public class HttpRequestSenderTests extends ESTestCase {
         when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService);
         when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
         when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class));
+        when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));

         return new HttpRequestSender.Factory(
             ServiceComponentsTests.createWithEmptySettings(mockThreadPool),
@@ -248,7 +239,7 @@ public class HttpRequestSenderTests extends ESTestCase {
         );
     }

-    public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) {
-        return factory.createSender(serviceName);
+    public static Sender createSender(HttpRequestSender.Factory factory) {
+        return factory.createSender();
     }
 }
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.external.http.sender;

 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;

 import static org.elasticsearch.xpack.inference.Utils.mockClusterService;

@@ -18,12 +19,23 @@ public class RequestExecutorServiceSettingsTests {
     }

     public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) {
+        return createRequestExecutorServiceSettings(queueCapacity, null);
+    }
+
+    public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(
+        @Nullable Integer queueCapacity,
+        @Nullable TimeValue staleDuration
+    ) {
         var settingsBuilder = Settings.builder();

         if (queueCapacity != null) {
             settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity);
         }

+        if (staleDuration != null) {
+            settingsBuilder.put(RequestExecutorServiceSettings.RATE_LIMIT_GROUP_STALE_DURATION_SETTING.getKey(), staleDuration);
+        }
+
         return createRequestExecutorServiceSettings(settingsBuilder.build());
     }

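The settings above configure per-group queues and how long an idle rate limit group may stay cached before cleanup. The limiter behind each group can be pictured as a token bucket that, instead of blocking, tells the caller how long to wait. This is a hedged sketch of that style of limiter — the class and method names are illustrative, not the PR's `RateLimiter` API.

```java
import java.time.Duration;
import java.time.Instant;

// Sketch: each reserve() call claims one request and returns how long the
// caller must wait before sending it, letting the executor sleep instead of spin.
final class TokenBucketSketch {
    private final double tokensPerSecond;
    private final double capacity;
    private double tokens;
    private Instant lastRefill;

    TokenBucketSketch(double requestsPerMinute, Instant now) {
        this.tokensPerSecond = requestsPerMinute / 60.0;
        this.capacity = requestsPerMinute;
        this.tokens = capacity;
        this.lastRefill = now;
    }

    /** Claim one request; returns how long to wait before it may be sent. */
    Duration reserve(Instant now) {
        refill(now);
        tokens -= 1.0;
        if (tokens >= 0) {
            return Duration.ZERO;
        }
        // Negative balance: wait until enough tokens have accrued to cover it.
        double secondsToWait = -tokens / tokensPerSecond;
        return Duration.ofNanos((long) (secondsToWait * 1_000_000_000L));
    }

    private void refill(Instant now) {
        double elapsedSeconds = Duration.between(lastRefill, now).toNanos() / 1_000_000_000.0;
        tokens = Math.min(capacity, tokens + elapsedSeconds * tokensPerSecond);
        lastRefill = now;
    }
}
```

With one request per minute, the first call returns zero wait and the second returns roughly a minute, which is the behavior the executor relies on when deciding how long to sleep.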
|
@ -18,13 +18,19 @@ import org.elasticsearch.core.Nullable;
|
|||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.Scheduler;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.common.RateLimiter;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
@ -42,10 +48,13 @@ import static org.elasticsearch.xpack.inference.external.http.sender.RequestExec
|
|||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class RequestExecutorServiceTests extends ESTestCase {
|
||||
|
@ -70,7 +79,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
|||
|
||||
public void testQueueSize_IsOne() {
|
||||
var service = createRequestExecutorServiceWithMocks();
|
||||
service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
||||
service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
||||
|
||||
assertThat(service.queueSize(), is(1));
|
||||
}
|
||||
|
@ -92,7 +101,20 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
|||
assertTrue(service.isTerminated());
|
||||
}
|
||||
|
||||
public void testIsTerminated_AfterStopFromSeparateThread() throws Exception {
|
||||
public void testCallingStartTwice_ThrowsAssertionException() throws InterruptedException {
|
||||
var latch = new CountDownLatch(1);
|
||||
var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class));
|
||||
|
||||
service.shutdown();
|
||||
service.start();
|
||||
latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
|
||||
|
||||
assertTrue(service.isTerminated());
|
||||
var exception = expectThrows(AssertionError.class, service::start);
|
||||
assertThat(exception.getMessage(), is("start() can only be called once"));
|
||||
}
|
||||
|
||||
public void testIsTerminated_AfterStopFromSeparateThread() {
|
||||
var waitToShutdown = new CountDownLatch(1);
|
||||
var waitToReturnFromSend = new CountDownLatch(1);
|
||||
|
||||
|
@@ -127,41 +149,48 @@ public class RequestExecutorServiceTests extends ESTestCase {
         assertTrue(service.isTerminated());
     }

-    public void testSend_AfterShutdown_Throws() {
+    public void testExecute_AfterShutdown_Throws() {
         var service = createRequestExecutorServiceWithMocks();

         service.shutdown();

+        var requestManager = RequestManagerTests.createMock("id");
         var listener = new PlainActionFuture<InferenceServiceResults>();
-        service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);

         var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));

         assertThat(
             thrownException.getMessage(),
-            is("Failed to enqueue task because the http executor service [test_service] has already shutdown")
+            is(
+                Strings.format(
+                    "Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );
         assertTrue(thrownException.isExecutorShutdown());
     }

-    public void testSend_Throws_WhenQueueIsFull() {
-        var service = new RequestExecutorService(
-            "test_service",
-            threadPool,
-            null,
-            createRequestExecutorServiceSettings(1),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
-        );
+    public void testExecute_Throws_WhenQueueIsFull() {
+        var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class));

-        service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
+        service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());

+        var requestManager = RequestManagerTests.createMock("id");
         var listener = new PlainActionFuture<InferenceServiceResults>();
-        service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);

         var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));

         assertThat(
             thrownException.getMessage(),
-            is("Failed to execute task because the http executor service [test_service] queue is full")
+            is(
+                Strings.format(
+                    "Failed to execute task for inference id [id] because the request service [%s] queue is full",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );
         assertFalse(thrownException.isExecutorShutdown());
     }
@@ -203,16 +232,11 @@ public class RequestExecutorServiceTests extends ESTestCase {
         assertTrue(service.isShutdown());
     }

-    public void testSend_CallsOnFailure_WhenRequestTimesOut() {
+    public void testExecute_CallsOnFailure_WhenRequestTimesOut() {
         var service = createRequestExecutorServiceWithMocks();

         var listener = new PlainActionFuture<InferenceServiceResults>();
-        service.execute(
-            ExecutableRequestCreatorTests.createMock(),
-            new DocumentsOnlyInput(List.of()),
-            TimeValue.timeValueNanos(1),
-            listener
-        );
+        service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener);

         var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT));

@@ -222,7 +246,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
         );
     }

-    public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
+    public void testExecute_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException {
         var headerKey = "not empty";
         var headerValue = "value";

@@ -270,7 +294,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
             }
         };

-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);

         Future<?> executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service);

@@ -280,11 +304,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
         finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
     }

-    public void testSend_NotifiesTasksOfShutdown() {
+    public void testExecute_NotifiesTasksOfShutdown() {
         var service = createRequestExecutorServiceWithMocks();

+        var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id");
         var listener = new PlainActionFuture<InferenceServiceResults>();
-        service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener);
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);

         service.shutdown();
         service.start();
@@ -293,47 +318,62 @@ public class RequestExecutorServiceTests extends ESTestCase {

         assertThat(
             thrownException.getMessage(),
-            is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
+            is(
+                Strings.format(
+                    "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );
         assertTrue(thrownException.isExecutorShutdown());
         assertTrue(service.isTerminated());
     }

-    public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
+    public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException {
         @SuppressWarnings("unchecked")
         BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);

+        var requestSender = mock(RetryingHttpSender.class);
+
         var service = new RequestExecutorService(
-            getTestName(),
             threadPool,
             mockQueueCreator(queue),
             null,
             createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
+            requestSender,
+            Clock.systemUTC(),
+            RequestExecutorService.DEFAULT_SLEEPER,
+            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
         );

-        when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        var requestManager = RequestManagerTests.createMock(requestSender, "id");
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
+
+        when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> {
             service.shutdown();
             return null;
         });
         service.start();

         assertTrue(service.isTerminated());
-        verify(queue, times(2)).take();
     }

-    public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception {
-        @SuppressWarnings("unchecked")
-        BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
-        when(queue.take()).thenThrow(new InterruptedException("failed"));
+    public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
+        var sleeper = mock(RequestExecutorService.Sleeper.class);
+        doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());

         var service = new RequestExecutorService(
-            getTestName(),
             threadPool,
-            mockQueueCreator(queue),
             null,
             createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
+            mock(RetryingHttpSender.class),
+            Clock.systemUTC(),
+            sleeper,
+            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
         );

         Future<?> executorTermination = threadPool.generic().submit(() -> {
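The hunks above switch the executor loop from a blocking `queue.take()` to a `queue.poll()` plus an injectable `Sleeper`, so between polls the service sleeps only the smaller of the time until the rate limiter releases the next token and the normal poll interval (the "only sleeping least amount" fix from the commit message). A minimal sketch of that sleep computation, using hypothetical names rather than code from the PR:

```java
import java.time.Duration;
import java.time.Instant;

// Sketch: pick the shortest useful sleep so a freed-up rate limit token is
// acted on promptly instead of waiting out the full poll interval.
final class SleepTimeSketch {
    static Duration sleepTime(Instant now, Instant nextTokenAt, Duration maxPollWait) {
        Duration untilToken = Duration.between(now, nextTokenAt);
        if (untilToken.isNegative()) {
            untilToken = Duration.ZERO; // a token is already available
        }
        return untilToken.compareTo(maxPollWait) < 0 ? untilToken : maxPollWait;
    }
}
```

Making the sleep injectable is also what lets the test above force an `InterruptedException` deterministically instead of interrupting a real blocked thread.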
@@ -347,66 +387,30 @@ public class RequestExecutorServiceTests extends ESTestCase {
         executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);

         assertTrue(service.isTerminated());
-        verify(queue, times(1)).take();
-    }
-
-    public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception {
-        var mockTask = mock(RejectableTask.class);
-        @SuppressWarnings("unchecked")
-        BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
-
-        var service = new RequestExecutorService(
-            "test_service",
-            threadPool,
-            mockQueueCreator(queue),
-            null,
-            createRequestExecutorServiceSettingsEmpty(),
-            new SingleRequestManager(mock(RetryingHttpSender.class))
-        );
-
-        doAnswer(invocation -> {
-            service.shutdown();
-            return mockTask;
-        }).doReturn(new NoopTask()).when(queue).take();
-
-        service.start();
-
-        assertTrue(service.isTerminated());
-        verify(queue, times(1)).take();
-
-        ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
-        verify(mockTask, times(1)).onRejection(argument.capture());
-        assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class));
-        assertThat(
-            argument.getValue().getMessage(),
-            is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
-        );
-
-        var rejectionException = (EsRejectedExecutionException) argument.getValue();
-        assertTrue(rejectionException.isExecutorShutdown());
     }

     public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException {
         var requestSender = mock(RetryingHttpSender.class);

         var settings = createRequestExecutorServiceSettings(1);
-        var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
+        var service = new RequestExecutorService(threadPool, null, settings, requestSender);

-        service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
-            new DocumentsOnlyInput(List.of()),
-            null,
-            new PlainActionFuture<>()
-        );
+        service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
         assertThat(service.queueSize(), is(1));

         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        var requestManager = RequestManagerTests.createMock(requestSender, "id");
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);

         var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
         assertThat(
             thrownException.getMessage(),
-            is("Failed to execute task because the http executor service [test_service] queue is full")
+            is(
+                Strings.format(
+                    "Failed to execute task for inference id [id] because the request service [%s] queue is full",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );

         settings.setQueueCapacity(2);
@@ -426,7 +430,7 @@ public class RequestExecutorServiceTests extends ESTestCase {

         executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
         assertTrue(service.isTerminated());
-        assertThat(service.remainingQueueCapacity(), is(2));
+        assertThat(service.remainingQueueCapacity(requestManager), is(2));
     }

     public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException,
@@ -434,23 +438,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
         var requestSender = mock(RetryingHttpSender.class);

         var settings = createRequestExecutorServiceSettings(3);
-        var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
+        var service = new RequestExecutorService(threadPool, null, settings, requestSender);

         service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
+            RequestManagerTests.createMock(requestSender, "id"),
             new DocumentsOnlyInput(List.of()),
             null,
             new PlainActionFuture<>()
         );
         service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
+            RequestManagerTests.createMock(requestSender, "id"),
             new DocumentsOnlyInput(List.of()),
             null,
             new PlainActionFuture<>()
         );

         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
+        var requestManager = RequestManagerTests.createMock(requestSender, "id");
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
         assertThat(service.queueSize(), is(3));

         settings.setQueueCapacity(1);
@@ -470,7 +475,7 @@ public class RequestExecutorServiceTests extends ESTestCase {

         executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
         assertTrue(service.isTerminated());
-        assertThat(service.remainingQueueCapacity(), is(1));
+        assertThat(service.remainingQueueCapacity(requestManager), is(1));
         assertThat(service.queueSize(), is(0));

         var thrownException = expectThrows(
@@ -479,7 +484,12 @@ public class RequestExecutorServiceTests extends ESTestCase {
         );
         assertThat(
             thrownException.getMessage(),
-            is("Failed to send request, queue service [test_service] has shutdown prior to executing request")
+            is(
+                Strings.format(
+                    "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request",
+                    requestManager.rateLimitGrouping().hashCode()
+                )
+            )
         );
         assertTrue(thrownException.isExecutorShutdown());
     }
@@ -489,23 +499,24 @@ public class RequestExecutorServiceTests extends ESTestCase {
         var requestSender = mock(RetryingHttpSender.class);

         var settings = createRequestExecutorServiceSettings(1);
-        var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender));
+        var service = new RequestExecutorService(threadPool, null, settings, requestSender);
+        var requestManager = RequestManagerTests.createMock(requestSender);

-        service.execute(
-            ExecutableRequestCreatorTests.createMock(requestSender),
-            new DocumentsOnlyInput(List.of()),
-            null,
-            new PlainActionFuture<>()
-        );
+        service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>());
|
||||
assertThat(service.queueSize(), is(1));
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener);
|
||||
service.execute(RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, listener);
|
||||
|
||||
var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("Failed to execute task because the http executor service [test_service] queue is full")
|
||||
is(
|
||||
Strings.format(
|
||||
"Failed to execute task for inference id [id] because the request service [%s] queue is full",
|
||||
requestManager.rateLimitGrouping().hashCode()
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
settings.setQueueCapacity(0);
|
||||
|
@ -525,7 +536,133 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
|||
|
||||
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
|
||||
assertTrue(service.isTerminated());
|
||||
assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE));
|
||||
assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE));
|
||||
}
|
||||
|
||||
public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
|
||||
var mockRateLimiter = mock(RateLimiter.class);
|
||||
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
|
||||
|
||||
var requestSender = mock(RetryingHttpSender.class);
|
||||
var settings = createRequestExecutorServiceSettings(1);
|
||||
var service = new RequestExecutorService(
|
||||
threadPool,
|
||||
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||
null,
|
||||
settings,
|
||||
requestSender,
|
||||
Clock.systemUTC(),
|
||||
RequestExecutorService.DEFAULT_SLEEPER,
|
||||
rateLimiterCreator
|
||||
);
|
||||
var requestManager = RequestManagerTests.createMock(requestSender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
service.shutdown();
|
||||
return TimeValue.timeValueDays(1);
|
||||
}).when(mockRateLimiter).timeToReserve(anyInt());
|
||||
|
||||
service.start();
|
||||
|
||||
verifyNoInteractions(requestSender);
|
||||
}
|
||||
|
||||
public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() {
|
||||
var mockRateLimiter = mock(RateLimiter.class);
|
||||
when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0));
|
||||
|
||||
RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter;
|
||||
|
||||
var requestSender = mock(RetryingHttpSender.class);
|
||||
var settings = createRequestExecutorServiceSettings(1);
|
||||
var service = new RequestExecutorService(
|
||||
threadPool,
|
||||
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||
null,
|
||||
settings,
|
||||
requestSender,
|
||||
Clock.systemUTC(),
|
||||
RequestExecutorService.DEFAULT_SLEEPER,
|
||||
rateLimiterCreator
|
||||
);
|
||||
var requestManager = RequestManagerTests.createMock(requestSender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||
|
||||
when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0));
|
||||
|
||||
doAnswer(invocation -> {
|
||||
service.shutdown();
|
||||
return Void.TYPE;
|
||||
}).when(requestSender).send(any(), any(), any(), any(), any(), any());
|
||||
|
||||
service.start();
|
||||
|
||||
verify(requestSender, times(1)).send(any(), any(), any(), any(), any(), any());
|
||||
}
|
||||
|
||||
public void testRemovesRateLimitGroup_AfterStaleDuration() {
|
||||
var now = Instant.now();
|
||||
var clock = mock(Clock.class);
|
||||
when(clock.instant()).thenReturn(now);
|
||||
|
||||
var requestSender = mock(RetryingHttpSender.class);
|
||||
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
|
||||
var service = new RequestExecutorService(
|
||||
threadPool,
|
||||
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||
null,
|
||||
settings,
|
||||
requestSender,
|
||||
clock,
|
||||
RequestExecutorService.DEFAULT_SLEEPER,
|
||||
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
|
||||
);
|
||||
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener);
|
||||
|
||||
assertThat(service.numberOfRateLimitGroups(), is(1));
|
||||
// the time is moved to after the stale duration, so now we should remove this grouping
|
||||
when(clock.instant()).thenReturn(now.plus(Duration.ofDays(2)));
|
||||
service.removeStaleGroupings();
|
||||
assertThat(service.numberOfRateLimitGroups(), is(0));
|
||||
|
||||
var requestManager2 = RequestManagerTests.createMock(requestSender, "id2");
|
||||
service.execute(requestManager2, new DocumentsOnlyInput(List.of()), null, listener);
|
||||
|
||||
assertThat(service.numberOfRateLimitGroups(), is(1));
|
||||
}
|
||||
|
||||
public void testStartsCleanupThread() {
|
||||
var mockThreadPool = mock(ThreadPool.class);
|
||||
|
||||
when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class));
|
||||
|
||||
var requestSender = mock(RetryingHttpSender.class);
|
||||
var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1));
|
||||
var service = new RequestExecutorService(
|
||||
mockThreadPool,
|
||||
RequestExecutorService.DEFAULT_QUEUE_CREATOR,
|
||||
null,
|
||||
settings,
|
||||
requestSender,
|
||||
Clock.systemUTC(),
|
||||
RequestExecutorService.DEFAULT_SLEEPER,
|
||||
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
|
||||
);
|
||||
|
||||
service.shutdown();
|
||||
service.start();
|
||||
|
||||
ArgumentCaptor<TimeValue> argument = ArgumentCaptor.forClass(TimeValue.class);
|
||||
verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), argument.capture(), any());
|
||||
assertThat(argument.getValue(), is(TimeValue.timeValueDays(1)));
|
||||
}
|
||||
|
||||
private Future<?> submitShutdownRequest(
|
||||
|
@ -552,12 +689,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) {
|
||||
return new RequestExecutorService(
|
||||
"test_service",
|
||||
threadPool,
|
||||
startupLatch,
|
||||
createRequestExecutorServiceSettingsEmpty(),
|
||||
new SingleRequestManager(requestSender)
|
||||
);
|
||||
return new RequestExecutorService(threadPool, startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender);
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
 import org.elasticsearch.xpack.inference.external.request.RequestTests;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyList;
@@ -21,34 +22,47 @@ import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;

-public class ExecutableRequestCreatorTests {
+public class RequestManagerTests {
     public static RequestManager createMock() {
-        var mockCreator = mock(RequestManager.class);
-        when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {});
+        return createMock(mock(RequestSender.class));
+    }

-        return mockCreator;
+    public static RequestManager createMock(String inferenceEntityId) {
+        return createMock(mock(RequestSender.class), inferenceEntityId);
     }

     public static RequestManager createMock(RequestSender requestSender) {
-        return createMock(requestSender, "id");
+        return createMock(requestSender, "id", new RateLimitSettings(1));
     }

-    public static RequestManager createMock(RequestSender requestSender, String modelId) {
-        var mockCreator = mock(RequestManager.class);
+    public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) {
+        return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1));
+    }
+
+    public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) {
+        var mockManager = mock(RequestManager.class);

         doAnswer(invocation -> {
             @SuppressWarnings("unchecked")
-            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[5];
-            return (Runnable) () -> requestSender.send(
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[4];
+            requestSender.send(
                 mock(Logger.class),
-                RequestTests.mockRequest(modelId),
+                RequestTests.mockRequest(inferenceEntityId),
                 HttpClientContext.create(),
                 () -> false,
                 mock(ResponseHandler.class),
                 listener
             );
-        }).when(mockCreator).create(any(), anyList(), any(), any(), any(), any());
-
-        return mockCreator;
+            return Void.TYPE;
+        }).when(mockManager).execute(any(), anyList(), any(), any(), any());
+
+        // just return something consistent so the hashing works
+        when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId);
+
+        when(mockManager.rateLimitSettings()).thenReturn(settings);
+        when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId);
+
+        return mockManager;
     }
 }

@@ -1,27 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.sender;
-
-import org.apache.http.client.protocol.HttpClientContext;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verifyNoInteractions;
-import static org.mockito.Mockito.when;
-
-public class SingleRequestManagerTests extends ESTestCase {
-    public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() {
-        var requestCreator = mock(RequestManager.class);
-        var request = mock(InferenceRequest.class);
-        when(request.getRequestCreator()).thenReturn(requestCreator);
-
-        new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), HttpClientContext.create());
-        verifyNoInteractions(requestCreator);
-    }
-}

@@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;

 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +58,7 @@ public class SenderServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
             PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
@@ -67,7 +66,7 @@ public class SenderServiceTests extends ESTestCase {

             listener.actionGet(TIMEOUT);
             verify(sender, times(1)).start();
-            verify(factory, times(1)).createSender(anyString());
+            verify(factory, times(1)).createSender();
         }

         verify(sender, times(1)).close();
@@ -79,7 +78,7 @@ public class SenderServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) {
             PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
@@ -89,7 +88,7 @@ public class SenderServiceTests extends ESTestCase {
             service.start(mock(Model.class), listener);
             listener.actionGet(TIMEOUT);

-            verify(factory, times(1)).createSender(anyString());
+            verify(factory, times(1)).createSender();
             verify(sender, times(2)).start();
         }

@@ -76,7 +76,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -819,7 +818,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -841,7 +840,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -112,7 +112,8 @@ public class AzureAiStudioChatCompletionServiceSettingsTests extends ESTestCase
         String xContentResult = Strings.toString(builder);

         assertThat(xContentResult, CoreMatchers.is("""
-            {"target":"target_value","provider":"openai","endpoint_type":"token"}"""));
+            {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
+            "rate_limit":{"requests_per_minute":3}}"""));
     }

     public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {

@@ -295,7 +295,7 @@ public class AzureAiStudioEmbeddingsServiceSettingsTests extends ESTestCase {

         assertThat(xContentResult, CoreMatchers.is("""
             {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """
-            "dimensions":1024,"max_input_tokens":512}"""));
+            "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}"""));
     }

     public static HashMap<String, Object> createRequestSettingsMap(

@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -594,7 +593,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -616,7 +615,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;

 import java.io.IOException;
@@ -46,7 +47,8 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
                     AzureOpenAiServiceFields.API_VERSION,
                     apiVersion
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null)));
@@ -63,18 +65,6 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
            {"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}"""));
-    }
-
     @Override
     protected Writeable.Reader<AzureOpenAiCompletionServiceSettings> instanceReader() {
         return AzureOpenAiCompletionServiceSettings::new;

@@ -389,7 +389,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase
            "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_RateLimit() throws IOException {
+    public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser() throws IOException {
         var entity = new AzureOpenAiEmbeddingsServiceSettings(
             "resource",
             "deployment",
@@ -408,7 +408,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase

         assertThat(xContentResult, is("""
             {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
-            "dimensions":1024,"max_input_tokens":512}"""));
+            "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":1}}"""));
     }

     @Override

@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -613,7 +612,7 @@ public class CohereServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -635,7 +634,7 @@ public class CohereServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.util.HashMap;
@@ -28,7 +29,8 @@ public class CohereCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of()),
             new HashMap<>(Map.of("model", "overridden model")),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -34,7 +35,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
         var model = "model";

         var serviceSettings = CohereCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model))
+            new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null)));
@@ -55,7 +57,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute))));
@@ -72,18 +75,6 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
            {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WithFilteredObject_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url","model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereCompletionServiceSettings> instanceReader() {
         return CohereCompletionServiceSettings::new;

@@ -331,21 +331,6 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase
            "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereEmbeddingsServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)),
-            CohereEmbeddingType.INT8
-        );
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """
-            "embedding_type":"byte"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereEmbeddingsServiceSettings> instanceReader() {
         return CohereEmbeddingsServiceSettings::new;

@@ -51,20 +51,6 @@ public class CohereRerankServiceSettingsTests extends AbstractWireSerializingTestCase
            "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereRerankServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3))
-        );
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        // TODO we probably shouldn't allow configuring these fields for reranking
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id"}"""));
-    }
-
     @Override
     protected Writeable.Reader<CohereRerankServiceSettings> instanceReader() {
         return CohereRerankServiceSettings::new;

@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.endsWith;
 import static org.hamcrest.Matchers.hasSize;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -494,7 +493,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
         var sender = mock(Sender.class);

         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -516,7 +515,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.net.URISyntaxException;
@@ -28,7 +29,8 @@ public class GoogleAiStudioCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of("model_id", "model")),
             new HashMap<>(Map.of()),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;

@@ -31,7 +32,10 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
     public void testFromMap_Request_CreatesSettingsCorrectly() {
         var model = "some model";

-        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)));
+        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
+        );

         assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null)));
     }
@@ -47,18 +51,6 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSerializingTestCase
            {"model_id":"model","rate_limit":{"requests_per_minute":360}}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioCompletionServiceSettings("model", null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader<GoogleAiStudioCompletionServiceSettings> instanceReader() {
         return GoogleAiStudioCompletionServiceSettings::new;

@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
 
@@ -55,7 +56,8 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
                     ServiceFields.SIMILARITY,
                     similarity.toString()
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
 
         assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null)));
@@ -80,23 +82,6 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
             }"""));
     }
 
-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
-            {
-                "model_id":"model",
-                "max_input_tokens": 1024,
-                "dimensions": 8,
-                "similarity": "dot_product"
-            }"""));
-    }
-
     @Override
     protected Writeable.Reader<GoogleAiStudioEmbeddingsServiceSettings> instanceReader() {
         return GoogleAiStudioEmbeddingsServiceSettings::new;
@@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.junit.After;
 import org.junit.Before;
@@ -33,7 +34,6 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.CoreMatchers.is;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +59,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
 
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
 
         var mockModel = getInvalidModel("model_id", "service_name");
 
@@ -81,7 +81,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );
 
-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }
 
@@ -111,7 +111,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ConfigurationParseContext context
     ) {
         return null;
     }
@@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -57,7 +58,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var dims = 384;
         var maxInputTokens = 128;
         {
-            var serviceSettings = HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)));
+            var serviceSettings = HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            );
             assertThat(serviceSettings, is(new HuggingFaceServiceSettings(url)));
         }
         {
@@ -73,7 +77,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                         ServiceFields.MAX_INPUT_TOKENS,
                         maxInputTokens
                     )
-                )
+                ),
+                ConfigurationParseContext.PERSISTENT
             );
             assertThat(
                 serviceSettings,
@@ -95,7 +100,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                         RateLimitSettings.FIELD_NAME,
                         new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
                     )
-                )
+                ),
+                ConfigurationParseContext.PERSISTENT
             );
             assertThat(
                 serviceSettings,
@@ -105,7 +111,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     }
 
     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );
 
         assertThat(
             thrownException.getMessage(),
@@ -118,7 +127,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT)
         );
 
         assertThat(
@@ -136,7 +145,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT)
         );
 
         assertThat(
@@ -152,7 +161,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var similarity = "by_size";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)))
+            () -> HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)),
+                ConfigurationParseContext.PERSISTENT
+            )
        );
 
         assertThat(
@@ -175,18 +187,6 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         {"url":"url","rate_limit":{"requests_per_minute":3}}"""));
     }
 
-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceServiceSettings(ServiceUtils.createUri("url"), null, null, null, new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url"}"""));
-    }
-
     @Override
     protected Writeable.Reader<HuggingFaceServiceSettings> instanceReader() {
         return HuggingFaceServiceSettings::new;
@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
@@ -32,7 +33,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
 
     public void testFromMap() {
         var url = "https://www.abc.com";
-        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)));
+        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(
+            new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+            ConfigurationParseContext.PERSISTENT
+        );
 
         assertThat(new HuggingFaceElserServiceSettings(url), is(serviceSettings));
     }
@@ -40,7 +44,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")),
+                ConfigurationParseContext.PERSISTENT
+            )
        );
 
         assertThat(
@@ -55,7 +62,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     }
 
     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );
 
         assertThat(
             thrownException.getMessage(),
@@ -72,7 +82,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            )
        );
 
         assertThat(
@@ -98,18 +111,6 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
         {"url":"url","max_input_tokens":512,"rate_limit":{"requests_per_minute":3}}"""));
     }
 
-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceElserServiceSettings(ServiceUtils.createUri("url"), new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url","max_input_tokens":512}"""));
-    }
-
     @Override
     protected Writeable.Reader<HuggingFaceElserServiceSettings> instanceReader() {
         return HuggingFaceElserServiceSettings::new;
@@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -393,7 +392,7 @@ public class MistralServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
 
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
 
         var mockModel = getInvalidModel("model_id", "service_name");
 
@@ -415,7 +414,7 @@ public class MistralServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );
 
-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }
@@ -98,18 +98,6 @@ public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
             "rate_limit":{"requests_per_minute":3}}"""));
     }
 
-    public void testToFilteredXContent_WritesFilteredValues() throws IOException {
-        var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, CoreMatchers.is("""
-            {"model":"model_name","dimensions":1024,"max_input_tokens":512}"""));
-    }
-
     public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
         var outputBuffer = new BytesStreamOutput();
         var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
@@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -675,7 +674,7 @@ public class OpenAiServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
 
         var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);
 
         var mockModel = getInvalidModel("model_id", "service_name");
 
@@ -697,7 +696,7 @@ public class OpenAiServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );
 
-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }
@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
@@ -48,7 +49,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
 
         assertThat(
@@ -77,7 +79,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
 
         assertThat(
@@ -101,7 +104,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );
 
         assertNull(serviceSettings.uri());
@@ -113,7 +117,10 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> OpenAiChatCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")))
+            () -> OpenAiChatCompletionServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
+            )
        );
 
         assertThat(
@@ -132,7 +139,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var maxInputTokens = 8192;
 
         var serviceSettings = OpenAiChatCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens))
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)),
+            ConfigurationParseContext.PERSISTENT
         );
 
         assertNull(serviceSettings.uri());
@@ -144,7 +152,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );
 
@@ -164,7 +173,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );
 
@@ -213,19 +223,6 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         {"model_id":"model","rate_limit":{"requests_per_minute":500}}"""));
     }
 
-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new OpenAiChatCompletionServiceSettings("model", "url", "org", 1024, new RateLimitSettings(2));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"model_id":"model","url":"url","organization_id":"org",""" + """
-            "max_input_tokens":1024}"""));
-    }
-
     @Override
     protected Writeable.Reader<OpenAiChatCompletionServiceSettings> instanceReader() {
         return OpenAiChatCompletionServiceSettings::new;