diff --git a/docs/changelog/107706.yaml b/docs/changelog/107706.yaml new file mode 100644 index 000000000000..76b7f662bf0e --- /dev/null +++ b/docs/changelog/107706.yaml @@ -0,0 +1,5 @@ +pr: 107706 +summary: Add rate limiting support for the Inference API +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index 354cee3f6a99..f805bc0cc92f 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -7,21 +7,17 @@ experimental[] Creates an {infer} endpoint to perform an {infer} task. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in -{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure -OpenAI, Google AI Studio or Hugging Face. For built-in models and models -uploaded though Eland, the {infer} APIs offer an alternative way to use and -manage trained models. However, if you do not plan to use the {infer} APIs to -use these models or if you want to use non-NLP models, use the +{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face. +For built-in models and models uploaded though Eland, the {infer} APIs offer an alternative way to use and manage trained models. +However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. - [discrete] [[put-inference-api-request]] ==== {api-request-title} `PUT /_inference//` - [discrete] [[put-inference-api-prereqs]] ==== {api-prereq-title} @@ -29,7 +25,6 @@ use these models or if you want to use non-NLP models, use the * Requires the `manage_inference` <> (the built-in `inference_admin` role grants this privilege) - [discrete] [[put-inference-api-desc]] ==== {api-description-title} @@ -48,25 +43,23 @@ The following services are available through the {infer} API: * Hugging Face * OpenAI - [discrete] [[put-inference-api-path-params]] ==== {api-path-parms-title} - ``:: (Required, string) The unique identifier of the {infer} endpoint. ``:: (Required, string) -The type of the {infer} task that the model will perform. Available task types: +The type of the {infer} task that the model will perform. +Available task types: * `completion`, * `rerank`, * `sparse_embedding`, * `text_embedding`. - [discrete] [[put-inference-api-request-body]] ==== {api-request-body-title} @@ -78,21 +71,18 @@ Available services: * `azureopenai`: specify the `completion` or `text_embedding` task type to use the Azure OpenAI service. * `azureaistudio`: specify the `completion` or `text_embedding` task type to use the Azure AI Studio service. -* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the -Cohere service. -* `elasticsearch`: specify the `text_embedding` task type to use the E5 -built-in model or text embedding models uploaded by Eland. +* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service. +* `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland. * `elser`: specify the `sparse_embedding` task type to use the ELSER service. * `googleaistudio`: specify the `completion` task to use the Google AI Studio service. -* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face -service. 
-* `openai`: specify the `completion` or `text_embedding` task type to use the -OpenAI service. +* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service. +* `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service. `service_settings`:: (Required, object) -Settings used to install the {infer} model. These settings are specific to the +Settings used to install the {infer} model. +These settings are specific to the `service` you specified. + .`service_settings` for the `azureaistudio` service @@ -104,11 +94,10 @@ Settings used to install the {infer} model. These settings are specific to the A valid API key of your Azure AI Studio model deployment. This key can be found on the overview page for your deployment in the management section of your https://ai.azure.com/[Azure AI Studio] account. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `target`::: (Required, string) @@ -142,11 +131,13 @@ For "real-time" endpoints which are billed per hour of usage, specify `realtime` By default, the `azureaistudio` service sets the number of requests allowed per minute to `240`. This helps to minimize the number of rate limit errors returned from Azure AI Studio. To modify this, set the `requests_per_minute` setting of this object in your service settings: -``` ++ +[source,text] +---- "rate_limit": { "requests_per_minute": <> } -``` +---- ===== + .`service_settings` for the `azureopenai` service @@ -181,6 +172,22 @@ Your Azure OpenAI deployments can be found though the https://oai.azure.com/[Azu The Azure API version ID to use. We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings[latest supported non-preview version]. +`rate_limit`::: +(Optional, object) +The `azureopenai` service sets a default number of requests allowed per minute depending on the task type. +For `text_embedding` it is set to `1440`. +For `completion` it is set to `120`. +This helps to minimize the number of rate limit errors returned from Azure. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- ++ +More information about the rate limits for Azure can be found in the https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits[Quota limits docs] and https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest[How to change the quotas]. ===== + .`service_settings` for the `cohere` service @@ -188,24 +195,24 @@ We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/opena ===== `api_key`::: (Required, string) -A valid API key of your Cohere account. You can find your Cohere API keys or you -can create a new one +A valid API key of your Cohere account. 
+You can find your Cohere API keys or you can create a new one https://dashboard.cohere.com/api-keys[on the API keys settings page]. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `embedding_type`:: (Optional, string) -Only for `text_embedding`. Specifies the types of embeddings you want to get -back. Defaults to `float`. +Only for `text_embedding`. +Specifies the types of embeddings you want to get back. +Defaults to `float`. Valid values are: - * `byte`: use it for signed int8 embeddings (this is a synonym of `int8`). - * `float`: use it for the default float embeddings. - * `int8`: use it for signed int8 embeddings. +* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`). +* `float`: use it for the default float embeddings. +* `int8`: use it for signed int8 embeddings. `model_id`:: (Optional, string) @@ -214,50 +221,68 @@ To review the available `rerank` models, refer to the https://docs.cohere.com/reference/rerank-1[Cohere docs]. To review the available `text_embedding` models, refer to the -https://docs.cohere.com/reference/embed[Cohere docs]. The default value for +https://docs.cohere.com/reference/embed[Cohere docs]. +The default value for `text_embedding` is `embed-english-v2.0`. + +`rate_limit`::: +(Optional, object) +By default, the `cohere` service sets the number of requests allowed per minute to `10000`. +This value is the same for all task types. +This helps to minimize the number of rate limit errors returned from Cohere. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- ++ +More information about Cohere's rate limits can be found in https://docs.cohere.com/docs/going-live#production-key-specifications[Cohere's production key docs]. + ===== + .`service_settings` for the `elasticsearch` service [%collapsible%closed] ===== + `model_id`::: (Required, string) -The name of the model to use for the {infer} task. It can be the -ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or -a text embedding model already +The name of the model to use for the {infer} task. +It can be the ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or a text embedding model already {ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland]. `num_allocations`::: (Required, integer) -The number of model allocations to create. `num_allocations` must not exceed the -number of available processors per node divided by the `num_threads`. +The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`. `num_threads`::: (Required, integer) -The number of threads to use by each model allocation. `num_threads` must not -exceed the number of available processors per node divided by the number of -allocations. 
Must be a power of 2. Max allowed value is 32. +The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations. +Must be a power of 2. Max allowed value is 32. + ===== + .`service_settings` for the `elser` service [%collapsible%closed] ===== + `num_allocations`::: (Required, integer) -The number of model allocations to create. `num_allocations` must not exceed the -number of available processors per node divided by the `num_threads`. +The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`. `num_threads`::: (Required, integer) -The number of threads to use by each model allocation. `num_threads` must not -exceed the number of available processors per node divided by the number of -allocations. Must be a power of 2. Max allowed value is 32. +The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations. +Must be a power of 2. Max allowed value is 32. + ===== + .`service_settings` for the `googleiastudio` service [%collapsible%closed] ===== + `api_key`::: (Required, string) A valid API key for the Google Gemini API. @@ -274,76 +299,113 @@ This helps to minimize the number of rate limit errors returned from Google AI S To modify this, set the `requests_per_minute` setting of this object in your service settings: + -- -``` +[source,text] +---- "rate_limit": { "requests_per_minute": <> } -``` +---- -- + ===== + .`service_settings` for the `hugging_face` service [%collapsible%closed] ===== + `api_key`::: (Required, string) -A valid access token of your Hugging Face account. You can find your Hugging -Face access tokens or you can create a new one +A valid access token of your Hugging Face account. +You can find your Hugging Face access tokens or you can create a new one https://huggingface.co/settings/tokens[on the settings page]. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `url`::: (Required, string) The URL endpoint to use for the requests. + +`rate_limit`::: +(Optional, object) +By default, the `huggingface` service sets the number of requests allowed per minute to `3000`. +This helps to minimize the number of rate limit errors returned from Hugging Face. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- + ===== + .`service_settings` for the `openai` service [%collapsible%closed] ===== + `api_key`::: (Required, string) -A valid API key of your OpenAI account. You can find your OpenAI API keys in -your OpenAI account under the +A valid API key of your OpenAI account. 
+You can find your OpenAI API keys in your OpenAI account under the https://platform.openai.com/api-keys[API keys section]. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `model_id`::: (Required, string) -The name of the model to use for the {infer} task. Refer to the +The name of the model to use for the {infer} task. +Refer to the https://platform.openai.com/docs/guides/embeddings/what-are-embeddings[OpenAI documentation] for the list of available text embedding models. `organization_id`::: (Optional, string) -The unique identifier of your organization. You can find the Organization ID in -your OpenAI account under +The unique identifier of your organization. +You can find the Organization ID in your OpenAI account under https://platform.openai.com/account/organization[**Settings** > **Organizations**]. `url`::: (Optional, string) -The URL endpoint to use for the requests. Can be changed for testing purposes. +The URL endpoint to use for the requests. +Can be changed for testing purposes. Defaults to `https://api.openai.com/v1/embeddings`. +`rate_limit`::: +(Optional, object) +The `openai` service sets a default number of requests allowed per minute depending on the task type. +For `text_embedding` it is set to `3000`. +For `completion` it is set to `500`. +This helps to minimize the number of rate limit errors returned from Azure. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- ++ +More information about the rate limits for OpenAI can be found in your https://platform.openai.com/account/limits[Account limits]. + ===== `task_settings`:: (Optional, object) -Settings to configure the {infer} task. These settings are specific to the +Settings to configure the {infer} task. +These settings are specific to the `` you specified. + .`task_settings` for the `completion` task type [%collapsible%closed] ===== + `do_sample`::: (Optional, float) For the `azureaistudio` service only. @@ -358,8 +420,8 @@ Defaults to 64. `user`::: (Optional, string) -For `openai` service only. Specifies the user issuing the request, which can be -used for abuse detection. +For `openai` service only. +Specifies the user issuing the request, which can be used for abuse detection. `temperature`::: (Optional, float) @@ -378,45 +440,46 @@ Should not be used if `temperature` is specified. .`task_settings` for the `rerank` task type [%collapsible%closed] ===== + `return_documents`:: (Optional, boolean) -For `cohere` service only. Specify whether to return doc text within the -results. +For `cohere` service only. +Specify whether to return doc text within the results. `top_n`:: (Optional, integer) -The number of most relevant documents to return, defaults to the number of the -documents. +The number of most relevant documents to return, defaults to the number of the documents. 
+ ===== + .`task_settings` for the `text_embedding` task type [%collapsible%closed] ===== + `input_type`::: (Optional, string) -For `cohere` service only. Specifies the type of input passed to the model. +For `cohere` service only. +Specifies the type of input passed to the model. Valid values are: - * `classification`: use it for embeddings passed through a text classifier. - * `clusterning`: use it for the embeddings run through a clustering algorithm. - * `ingest`: use it for storing document embeddings in a vector database. - * `search`: use it for storing embeddings of search queries run against a - vector database to find relevant documents. +* `classification`: use it for embeddings passed through a text classifier. +* `clusterning`: use it for the embeddings run through a clustering algorithm. +* `ingest`: use it for storing document embeddings in a vector database. +* `search`: use it for storing embeddings of search queries run against a vector database to find relevant documents. `truncate`::: (Optional, string) -For `cohere` service only. Specifies how the API handles inputs longer than the -maximum token length. Defaults to `END`. Valid values are: - * `NONE`: when the input exceeds the maximum input token length an error is - returned. - * `START`: when the input exceeds the maximum input token length the start of - the input is discarded. - * `END`: when the input exceeds the maximum input token length the end of - the input is discarded. +For `cohere` service only. +Specifies how the API handles inputs longer than the maximum token length. +Defaults to `END`. +Valid values are: +* `NONE`: when the input exceeds the maximum input token length an error is returned. +* `START`: when the input exceeds the maximum input token length the start of the input is discarded. +* `END`: when the input exceeds the maximum input token length the end of the input is discarded. `user`::: (optional, string) -For `openai`, `azureopenai` and `azureaistudio` services only. Specifies the user issuing the -request, which can be used for abuse detection. +For `openai`, `azureopenai` and `azureaistudio` services only. +Specifies the user issuing the request, which can be used for abuse detection. ===== [discrete] @@ -470,7 +533,6 @@ PUT _inference/completion/azure_ai_studio_completion The list of chat completion models that you can choose from in your deployment can be found in the https://ai.azure.com/explore/models?selectedTask=chat-completion[Azure AI Studio model explorer]. - [discrete] [[inference-example-azureopenai]] ===== Azure OpenAI service @@ -519,7 +581,6 @@ The list of chat completion models that you can choose from in your Azure OpenAI * https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-models[GPT-4 and GPT-4 Turbo models] * https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35[GPT-3.5] - [discrete] [[inference-example-cohere]] ===== Cohere service @@ -565,7 +626,6 @@ PUT _inference/rerank/cohere-rerank For more examples, also review the https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation]. - [discrete] [[inference-example-e5]] ===== E5 via the `elasticsearch` service @@ -586,10 +646,9 @@ PUT _inference/text_embedding/my-e5-model } ------------------------------------------------------------ // TEST[skip:TBD] -<1> The `model_id` must be the ID of one of the built-in E5 models. 
Valid values -are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`. For -further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation]. - +<1> The `model_id` must be the ID of one of the built-in E5 models. +Valid values are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`. +For further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation]. [discrete] [[inference-example-elser]] @@ -597,8 +656,7 @@ further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation]. The following example shows how to create an {infer} endpoint called `my-elser-model` to perform a `sparse_embedding` task type. -Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more -info. +Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more info. [source,console] ------------------------------------------------------------ @@ -672,16 +730,17 @@ PUT _inference/text_embedding/hugging-face-embeddings } ------------------------------------------------------------ // TEST[skip:TBD] -<1> A valid Hugging Face access token. You can find on the +<1> A valid Hugging Face access token. +You can find on the https://huggingface.co/settings/tokens[settings page of your account]. <2> The {infer} endpoint URL you created on Hugging Face. Create a new {infer} endpoint on -https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an -endpoint URL. Select the model you want to use on the new endpoint creation page -- for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings` -task under the Advanced configuration section. Create the endpoint. Copy the URL -after the endpoint initialization has been finished. +https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an endpoint URL. +Select the model you want to use on the new endpoint creation page - for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings` +task under the Advanced configuration section. +Create the endpoint. +Copy the URL after the endpoint initialization has been finished. [discrete] [[inference-example-hugging-face-supported-models]] @@ -695,7 +754,6 @@ The list of recommended models for the Hugging Face service: * https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base] * https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small] - [discrete] [[inference-example-eland]] ===== Models uploaded by Eland via the elasticsearch service @@ -716,11 +774,9 @@ PUT _inference/text_embedding/my-msmarco-minilm-model } ------------------------------------------------------------ // TEST[skip:TBD] -<1> The `model_id` must be the ID of a text embedding model which has already -been +<1> The `model_id` must be the ID of a text embedding model which has already been {ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland]. 
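The per-service `rate_limit` objects described earlier can be combined with any of these examples.
As an illustrative sketch only (the endpoint name and the `100` requests-per-minute value are placeholders chosen for this example, not defaults taken from the API), lowering the Cohere service's default of `10000` requests per minute could look like this:

[source,console]
------------------------------------------------------------
PUT _inference/text_embedding/cohere-embeddings-throttled
{
    "service": "cohere",
    "service_settings": {
        "api_key": "<api_key>",
        "model_id": "embed-english-light-v3.0",
        "rate_limit": {
            "requests_per_minute": 100
        }
    }
}
------------------------------------------------------------
// TEST[skip:TBD]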
- [discrete] [[inference-example-openai]] ===== OpenAI service @@ -756,4 +812,3 @@ PUT _inference/completion/openai-completion } ------------------------------------------------------------ // TEST[skip:TBD] - diff --git a/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java b/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java index df7c47943289..26d93bca6b09 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java +++ b/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java @@ -88,6 +88,13 @@ public class TimeValue implements Comparable { return new TimeValue(days, TimeUnit.DAYS); } + /** + * @return the {@link TimeValue} object that has the least duration. + */ + public static TimeValue min(TimeValue time1, TimeValue time2) { + return time1.compareTo(time2) < 0 ? time1 : time2; + } + /** * @return the unit used for the this time value, see {@link #duration()} */ diff --git a/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java b/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java index b6481db9b995..dd2755ac1f9f 100644 --- a/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java +++ b/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java @@ -17,6 +17,7 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.object.HasToString.hasToString; @@ -231,6 +232,12 @@ public class TimeValueTests extends ESTestCase { assertThat(ex.getMessage(), containsString("duration cannot be negative")); } + public void testMin() { + assertThat(TimeValue.min(TimeValue.ZERO, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(0))); + assertThat(TimeValue.min(TimeValue.MAX_VALUE, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(1))); + assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE)); + } + private TimeUnit randomTimeUnitObject() { return randomFrom( TimeUnit.NANOSECONDS, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index 140c08ceef80..81bc90433d34 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor { private final ServiceComponents serviceComponents; public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { + // TODO Batching - accept a class that can handle batching this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java index 63e51d99a8ce..b4815f8f0d1b 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -36,6 +36,7 @@ public class CohereEmbeddingsAction implements ExecutableAction { model.getServiceSettings().getCommonSettings().uri(), "Cohere embeddings" ); + // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index deff410aebaa..002fa71b7fb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -37,17 +36,16 @@ public class AzureAiStudioChatCompletionRequestManager extends AzureAiStudioRequ } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createCompletionHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index a2b363151a41..ec5ab2fee6a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -41,17 +40,16 @@ public class AzureAiStudioEmbeddingsRequestManager extends AzureAiStudioRequestM } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); AzureAiStudioEmbeddingsRequest request = new 
AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createEmbeddingsHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index 2811155f6f35..5206d6c2c23c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -43,16 +42,15 @@ public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManag } @Override - public Runnable create( + public void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index 06152b50822a..e0fcee30e5af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,16 +54,15 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } 
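Earlier in this patch, a `TimeValue.min` helper is added to `libs/core` together with a unit test.
A minimal, self-contained sketch of how the helper behaves (the class and variable names below are illustrative, not taken from the PR):

[source,java]
----
import org.elasticsearch.core.TimeValue;

class TimeValueMinExample {
    public static void main(String[] args) {
        TimeValue taskDelay = TimeValue.timeValueSeconds(30);
        TimeValue pollInterval = TimeValue.timeValueSeconds(10);

        // Returns the argument with the smaller duration; when the durations are
        // equal, the second argument is returned (the implementation checks compareTo < 0).
        TimeValue sleepTime = TimeValue.min(taskDelay, pollInterval);
        System.out.println(sleepTime); // 10s
    }
}
----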
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java index abca0ce0d049..a015716b8103 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java @@ -38,11 +38,16 @@ abstract class BaseRequestManager implements RequestManager { @Override public Object rateLimitGrouping() { - return rateLimitGroup; + // It's possible that two inference endpoints have the same information defining the group but have different + // rate limits then they should be in different groups otherwise whoever initially created the group will set + // the rate and the other inference endpoint's rate will be ignored + return new EndpointGrouping(rateLimitGroup, rateLimitSettings); } @Override public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + + private record EndpointGrouping(Object group, RateLimitSettings settings) {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index 255d4a3f3879..8a4b0e45b93f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -46,16 +45,15 @@ public class CohereCompletionRequestManager extends CohereRequestManager { } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereCompletionRequest request = new CohereCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index 0bf1c11285ad..a51910f1d0a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import 
org.elasticsearch.action.ActionListener; @@ -44,16 +43,15 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager { } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 1778663a194e..1351eec40656 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -44,16 +43,15 @@ public class CohereRerankRequestManager extends CohereRequestManager { } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereRerankRequest request = new CohereRerankRequest(query, input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java index 53f30773cbfe..214eba4ee348 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java @@ -23,7 +23,6 @@ record ExecutableInferenceRequest( RequestSender requestSender, Logger logger, Request request, - HttpClientContext context, ResponseHandler responseHandler, Supplier hasFinished, ActionListener listener @@ -34,7 +33,7 @@ record ExecutableInferenceRequest( var inferenceEntityId = request.createHttpRequest().inferenceEntityId(); try { - requestSender.send(logger, request, context, hasFinished, responseHandler, listener); + requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener); } catch (Exception e) { var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId); logger.warn(errorMessage, e); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index eb9baa680446..2b191b046477 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -42,15 +41,14 @@ public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioReques } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java index 15c2825e7d04..6436e0231ab4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -48,17 +47,16 @@ public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioReques } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index 21a758a3db24..d1e309a774ab 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -15,6 +15,8 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -39,30 +41,28 @@ public class HttpRequestSender implements Sender { private final ServiceComponents serviceComponents; private final HttpClientManager httpClientManager; private final ClusterService clusterService; - private final SingleRequestManager requestManager; + private final RequestSender requestSender; public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) { this.serviceComponents = Objects.requireNonNull(serviceComponents); this.httpClientManager = Objects.requireNonNull(httpClientManager); this.clusterService = Objects.requireNonNull(clusterService); - var requestSender = new RetryingHttpSender( + requestSender = new RetryingHttpSender( this.httpClientManager.getHttpClient(), serviceComponents.throttlerManager(), new RetrySettings(serviceComponents.settings(), clusterService), serviceComponents.threadPool() ); - requestManager = new SingleRequestManager(requestSender); } - public Sender createSender(String serviceName) { + public Sender createSender() { return new HttpRequestSender( - serviceName, serviceComponents.threadPool(), httpClientManager, clusterService, serviceComponents.settings(), - requestManager + requestSender ); } } @@ -71,26 +71,24 @@ public class HttpRequestSender implements Sender { private final ThreadPool threadPool; private final HttpClientManager manager; - private final RequestExecutorService service; + private final RequestExecutor service; private final AtomicBoolean started = new AtomicBoolean(false); private final CountDownLatch startCompleted = new CountDownLatch(1); private HttpRequestSender( - String serviceName, ThreadPool threadPool, HttpClientManager httpClientManager, ClusterService clusterService, Settings settings, - SingleRequestManager requestManager + RequestSender requestSender ) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); service = new RequestExecutorService( - serviceName, threadPool, startCompleted, new RequestExecutorServiceSettings(settings, clusterService), - requestManager + requestSender ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 7c09e0c67c1c..6c8fc446d524 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -7,7 
+7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,26 +54,17 @@ public class HuggingFaceRequestManager extends BaseRequestManager { } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest( - requestSender, - logger, - request, - context, - responseHandler, - hasRequestCompletedFunction, - listener - ); + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int accountHash) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java index 3c711bb79717..6199a75a41a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java @@ -19,9 +19,9 @@ import java.util.function.Supplier; public interface InferenceRequest { /** - * Returns the creator that handles building an executable request based on the input provided. + * Returns the manager that handles building and executing an inference request. */ - RequestManager getRequestCreator(); + RequestManager getRequestManager(); /** * Returns the query associated with this request. Used for Rerank tasks. 
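The `EndpointGrouping` record added to `BaseRequestManager` above relies on record value equality: two inference endpoints that share the same grouping information but configure different rate limits must map to different rate limit groups, otherwise the endpoint that first created the group would fix the rate for both.
A small sketch of that effect using stand-in types (simplified placeholders, not the real `RateLimitSettings` or grouping classes from the plugin):

[source,java]
----
import java.util.Objects;

class EndpointGroupingSketch {
    // Stand-in for RateLimitSettings.
    record RateLimit(long requestsPerMinute) {}

    // Mirrors the shape of BaseRequestManager.EndpointGrouping: (grouping key, rate limit settings).
    record Grouping(Object group, RateLimit settings) {}

    public static void main(String[] args) {
        var sameLimit = new Grouping("cohere-account-hash", new RateLimit(10_000));
        var otherLimit = new Grouping("cohere-account-hash", new RateLimit(500));

        // Records derive equals/hashCode from their components, so a different
        // rate limit yields a distinct key and therefore an independent rate limiter.
        System.out.println(Objects.equals(sameLimit, otherLimit)); // false
    }
}
----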
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index f31a63358170..ab6a1bfb3137 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -51,18 +50,17 @@ public class MistralEmbeddingsRequestManager extends BaseRequestManager { } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int keyHashCode) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java deleted file mode 100644 index 0355880b3f71..000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; - -import java.util.List; -import java.util.function.Supplier; - -class NoopTask implements RejectableTask { - - @Override - public RequestManager getRequestCreator() { - return null; - } - - @Override - public String getQuery() { - return null; - } - - @Override - public List getInput() { - return null; - } - - @Override - public ActionListener getListener() { - return null; - } - - @Override - public boolean hasCompleted() { - return true; - } - - @Override - public Supplier getRequestCompletedFunction() { - return () -> true; - } - - @Override - public void onRejection(Exception e) { - - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 9c6c216c6127..7bc09fd76736 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -43,17 +42,16 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager { } @Override - public Runnable create( + public void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createCompletionHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java index 3a0a8fd64a65..41f91d2b89ee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,17 +54,16 @@ public class OpenAiEmbeddingsRequestManager extends OpenAiRequestManager { } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, 
model.getServiceSettings().maxInputTokens()); OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index d5a13c2e0771..38d47aec68eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -17,21 +16,31 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; +import org.elasticsearch.xpack.inference.common.RateLimiter; import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.time.Clock; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; /** * A service for queuing and executing {@link RequestTask}. This class is useful because the @@ -45,7 +54,18 @@ import static org.elasticsearch.core.Strings.format; * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. 
*/ class RequestExecutorService implements RequestExecutor { - private static final AdjustableCapacityBlockingQueue.QueueCreator QUEUE_CREATOR = + + /** + * Provides dependency injection mainly for testing + */ + interface Sleeper { + void sleep(TimeValue sleepTime) throws InterruptedException; + } + + // default for tests + static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration()); + // default for tests + static final AdjustableCapacityBlockingQueue.QueueCreator DEFAULT_QUEUE_CREATOR = new AdjustableCapacityBlockingQueue.QueueCreator<>() { @Override public BlockingQueue create(int capacity) { @@ -65,86 +85,116 @@ class RequestExecutorService implements RequestExecutor { } }; + /** + * Provides dependency injection mainly for testing + */ + interface RateLimiterCreator { + RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit); + } + + // default for testing + static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new; private static final Logger logger = LogManager.getLogger(RequestExecutorService.class); - private final String serviceName; - private final AdjustableCapacityBlockingQueue queue; - private final AtomicBoolean running = new AtomicBoolean(true); - private final CountDownLatch terminationLatch = new CountDownLatch(1); - private final HttpClientContext httpContext; + private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1); + + private final ConcurrentMap rateLimitGroupings = new ConcurrentHashMap<>(); private final ThreadPool threadPool; private final CountDownLatch startupLatch; - private final BlockingQueue controlQueue = new LinkedBlockingQueue<>(); - private final SingleRequestManager requestManager; + private final CountDownLatch terminationLatch = new CountDownLatch(1); + private final RequestSender requestSender; + private final RequestExecutorServiceSettings settings; + private final Clock clock; + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final AdjustableCapacityBlockingQueue.QueueCreator queueCreator; + private final Sleeper sleeper; + private final RateLimiterCreator rateLimiterCreator; + private final AtomicReference cancellableCleanupTask = new AtomicReference<>(); + private final AtomicBoolean started = new AtomicBoolean(false); RequestExecutorService( - String serviceName, ThreadPool threadPool, @Nullable CountDownLatch startupLatch, RequestExecutorServiceSettings settings, - SingleRequestManager requestManager + RequestSender requestSender ) { - this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager); + this( + threadPool, + DEFAULT_QUEUE_CREATOR, + startupLatch, + settings, + requestSender, + Clock.systemUTC(), + DEFAULT_SLEEPER, + DEFAULT_RATE_LIMIT_CREATOR + ); } - /** - * This constructor should only be used directly for testing. 
- */ RequestExecutorService( - String serviceName, ThreadPool threadPool, - AdjustableCapacityBlockingQueue.QueueCreator createQueue, + AdjustableCapacityBlockingQueue.QueueCreator queueCreator, @Nullable CountDownLatch startupLatch, RequestExecutorServiceSettings settings, - SingleRequestManager requestManager + RequestSender requestSender, + Clock clock, + Sleeper sleeper, + RateLimiterCreator rateLimiterCreator ) { - this.serviceName = Objects.requireNonNull(serviceName); this.threadPool = Objects.requireNonNull(threadPool); - this.httpContext = HttpClientContext.create(); - this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.queueCreator = Objects.requireNonNull(queueCreator); this.startupLatch = startupLatch; - this.requestManager = Objects.requireNonNull(requestManager); - - Objects.requireNonNull(settings); - settings.registerQueueCapacityCallback(this::onCapacityChange); + this.requestSender = Objects.requireNonNull(requestSender); + this.settings = Objects.requireNonNull(settings); + this.clock = Objects.requireNonNull(clock); + this.sleeper = Objects.requireNonNull(sleeper); + this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); } - private void onCapacityChange(int capacity) { - logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity)); - - var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity)); - if (enqueuedCapacityCommand == false) { - logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later."); - } else { - // ensure that the task execution loop wakes up - queue.offer(new NoopTask()); + public void shutdown() { + if (shutdown.compareAndSet(false, true)) { + if (cancellableCleanupTask.get() != null) { + logger.debug(() -> "Stopping clean up thread"); + cancellableCleanupTask.get().cancel(); + } } } - private void updateCapacity(int newCapacity) { - try { - queue.setCapacity(newCapacity); - } catch (Exception e) { - logger.warn( - format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName), - e - ); - } + public boolean isShutdown() { + return shutdown.get(); + } + + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return terminationLatch.await(timeout, unit); + } + + public boolean isTerminated() { + return terminationLatch.getCount() == 0; + } + + public int queueSize() { + return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } /** * Begin servicing tasks. + *

+ * Note: This should only be called once for the life of the object. + *

*/ public void start() { try { + assert started.get() == false : "start() can only be called once"; + started.set(true); + + startCleanupTask(); signalStartInitiated(); - while (running.get()) { + while (isShutdown() == false) { handleTasks(); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); } finally { - running.set(false); + shutdown(); notifyRequestsOfShutdown(); terminationLatch.countDown(); } @@ -156,108 +206,68 @@ class RequestExecutorService implements RequestExecutor { } } - /** - * Protects the task retrieval logic from an unexpected exception. - * - * @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to - * shut down - */ + private void startCleanupTask() { + assert cancellableCleanupTask.get() == null : "The clean up task can only be set once"; + cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL)); + } + + private Scheduler.Cancellable startCleanupThread(TimeValue interval) { + logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval)); + + return threadPool.scheduleWithFixedDelay(this::removeStaleGroupings, interval, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + // default for testing + void removeStaleGroupings() { + var now = Instant.now(clock); + for (var iter = rateLimitGroupings.values().iterator(); iter.hasNext();) { + var endpoint = iter.next(); + + // if the current time is after the last time the endpoint enqueued a request + allowed stale period then we'll remove it + if (now.isAfter(endpoint.timeOfLastEnqueue().plus(settings.getRateLimitGroupStaleDuration()))) { + endpoint.close(); + iter.remove(); + } + } + } + private void handleTasks() throws InterruptedException { - try { - RejectableTask task = queue.take(); - - var command = controlQueue.poll(); - if (command != null) { - command.run(); - } - - // TODO add logic to complete pending items in the queue before shutting down - if (running.get() == false) { - logger.debug(() -> format("Http executor service [%s] exiting", serviceName)); - rejectTaskBecauseOfShutdown(task); - } else { - executeTask(task); - } - } catch (InterruptedException e) { - throw e; - } catch (Exception e) { - logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e); + var timeToWait = settings.getTaskPollFrequency(); + for (var endpoint : rateLimitGroupings.values()) { + timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); } + + sleeper.sleep(timeToWait); } - private void executeTask(RejectableTask task) { - try { - requestManager.execute(task, httpContext); - } catch (Exception e) { - logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e); - } - } - - private synchronized void notifyRequestsOfShutdown() { + private void notifyRequestsOfShutdown() { assert isShutdown() : "Requests should only be notified if the executor is shutting down"; - try { - List notExecuted = new ArrayList<>(); - queue.drainTo(notExecuted); - - rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown); - } catch (Exception e) { - logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName)); + for (var endpoint : rateLimitGroupings.values()) { + endpoint.notifyRequestsOfShutdown(); } } - private void rejectTaskBecauseOfShutdown(RejectableTask task) { - try { - task.onRejection( - new EsRejectedExecutionException( - format("Failed to send request, queue 
service [%s] has shutdown prior to executing request", serviceName), - true - ) - ); - } catch (Exception e) { - logger.warn( - format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName) - ); + // default for testing + Integer remainingQueueCapacity(RequestManager requestManager) { + var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping()); + + if (endpoint == null) { + return null; } + + return endpoint.remainingCapacity(); } - private void rejectTasks(List tasks, Consumer rejectionFunction) { - for (var task : tasks) { - rejectionFunction.accept(task); - } - } - - public int queueSize() { - return queue.size(); - } - - @Override - public void shutdown() { - if (running.compareAndSet(true, false)) { - // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns - queue.offer(new NoopTask()); - } - } - - @Override - public boolean isShutdown() { - return running.get() == false; - } - - @Override - public boolean isTerminated() { - return terminationLatch.getCount() == 0; - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return terminationLatch.await(timeout, unit); + // default for testing + int numberOfRateLimitGroups() { + return rateLimitGroupings.size(); } /** * Execute the request at some point in the future. * - * @param requestCreator the http request to send + * @param requestManager the http request to send * @param inferenceInputs the inputs to send in the request * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}. 
@@ -265,13 +275,13 @@ class RequestExecutorService implements RequestExecutor { * @param listener an {@link ActionListener} for the response or failure */ public void execute( - RequestManager requestCreator, + RequestManager requestManager, InferenceInputs inferenceInputs, @Nullable TimeValue timeout, ActionListener listener ) { var task = new RequestTask( - requestCreator, + requestManager, inferenceInputs, timeout, threadPool, @@ -280,38 +290,230 @@ class RequestExecutorService implements RequestExecutor { ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) ); - completeExecution(task); - } - - private void completeExecution(RequestTask task) { - if (isShutdown()) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName), - true + var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { + var endpointHandler = new RateLimitingEndpointHandler( + Integer.toString(requestManager.rateLimitGrouping().hashCode()), + queueCreator, + settings, + requestSender, + clock, + requestManager.rateLimitSettings(), + this::isShutdown, + rateLimiterCreator ); - task.onRejection(rejected); - return; - } + endpointHandler.init(); + return endpointHandler; + }); - boolean added = queue.offer(task); - if (added == false) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format("Failed to execute task because the http executor service [%s] queue is full", serviceName), - false - ); - - task.onRejection(rejected); - } else if (isShutdown()) { - // It is possible that a shutdown and notification request occurred after we initially checked for shutdown above - // If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if - // we shut down and if so we'll redo the notification - notifyRequestsOfShutdown(); - } + endpoint.enqueue(task); } - // default for testing - int remainingQueueCapacity() { - return queue.remainingCapacity(); + /** + * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests. + * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent + * as soon as a thread is available. 
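The reworked execute() above no longer pushes tasks onto a single shared queue: it lazily creates one handler per rate-limit grouping via computeIfAbsent and enqueues the task with that handler. A minimal sketch of that lazy per-key pattern follows, assuming hypothetical GroupKey and Handler types rather than the classes in this change.

[source,java]
----
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

// Illustrative sketch: tasks that share a rate-limit grouping share one handler and one queue.
public class PerGroupDispatchSketch {

    // Hypothetical grouping key, e.g. derived from the account/API key and model.
    record GroupKey(String apiKeyHash, String modelId) {}

    // Hypothetical per-group handler that owns its own queue.
    static final class Handler {
        final Queue<Runnable> queue = new ConcurrentLinkedQueue<>();

        void enqueue(Runnable task) {
            queue.add(task);
        }
    }

    private final Map<GroupKey, Handler> handlers = new ConcurrentHashMap<>();

    // Lazily create the handler for this task's grouping, then hand the task to it.
    public void execute(GroupKey key, Runnable task) {
        Handler handler = handlers.computeIfAbsent(key, k -> new Handler());
        handler.enqueue(task);
    }

    public static void main(String[] args) {
        PerGroupDispatchSketch dispatcher = new PerGroupDispatchSketch();
        GroupKey key = new GroupKey("api-key-hash", "embedding-model");

        dispatcher.execute(key, () -> System.out.println("first request queued"));
        dispatcher.execute(key, () -> System.out.println("same grouping reuses the handler"));

        System.out.println("handlers created: " + dispatcher.handlers.size()); // prints 1
    }
}
----

In the real code the newly created handler also registers a queue-capacity callback via init() before computeIfAbsent returns it, so dynamic queue-capacity changes reach every grouping.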
+ */ + private static class RateLimitingEndpointHandler { + + private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE; + private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO; + private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class); + private static final int ACCUMULATED_TOKENS_LIMIT = 1; + + private final AdjustableCapacityBlockingQueue queue; + private final Supplier isShutdownMethod; + private final RequestSender requestSender; + private final String id; + private final AtomicReference timeOfLastEnqueue = new AtomicReference<>(); + private final Clock clock; + private final RateLimiter rateLimiter; + private final RequestExecutorServiceSettings requestExecutorServiceSettings; + + RateLimitingEndpointHandler( + String id, + AdjustableCapacityBlockingQueue.QueueCreator createQueue, + RequestExecutorServiceSettings settings, + RequestSender requestSender, + Clock clock, + RateLimitSettings rateLimitSettings, + Supplier isShutdownMethod, + RateLimiterCreator rateLimiterCreator + ) { + this.requestExecutorServiceSettings = Objects.requireNonNull(settings); + this.id = Objects.requireNonNull(id); + this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.requestSender = Objects.requireNonNull(requestSender); + this.clock = Objects.requireNonNull(clock); + this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod); + + Objects.requireNonNull(rateLimitSettings); + Objects.requireNonNull(rateLimiterCreator); + rateLimiter = rateLimiterCreator.create( + ACCUMULATED_TOKENS_LIMIT, + rateLimitSettings.requestsPerTimeUnit(), + rateLimitSettings.timeUnit() + ); + + } + + public void init() { + requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange); + } + + private void onCapacityChange(int capacity) { + logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity)); + + try { + queue.setCapacity(capacity); + } catch (Exception e) { + logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e); + } + } + + public int queueSize() { + return queue.size(); + } + + public boolean isShutdown() { + return isShutdownMethod.get(); + } + + public Instant timeOfLastEnqueue() { + return timeOfLastEnqueue.get(); + } + + public synchronized TimeValue executeEnqueuedTask() { + try { + return executeEnqueuedTaskInternal(); + } catch (Exception e) { + logger.warn(format("Executor service grouping [%s] failed to execute request", id), e); + // we tried to do some work but failed, so we'll say we did something to try looking for more work + return EXECUTED_A_TASK; + } + } + + private TimeValue executeEnqueuedTaskInternal() { + var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); + if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { + return timeBeforeAvailableToken; + } + + var task = queue.poll(); + + // TODO Batching - in a situation where no new tasks are queued we'll want to execute any prepared tasks + // So we'll need to check for null and call a helper method executePreparedTasks() + + if (shouldExecuteTask(task) == false) { + return NO_TASKS_AVAILABLE; + } + + // We should never have to wait because we checked above + var reserveRes = rateLimiter.reserve(1); + assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have"; + + task.getRequestManager() + .execute(task.getQuery(), 
task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener()); + return EXECUTED_A_TASK; + } + + private static boolean shouldExecuteTask(RejectableTask task) { + return task != null && isNoopRequest(task) == false && task.hasCompleted() == false; + } + + private static boolean isNoopRequest(InferenceRequest inferenceRequest) { + return inferenceRequest.getRequestManager() == null + || inferenceRequest.getInput() == null + || inferenceRequest.getListener() == null; + } + + private static boolean shouldExecuteImmediately(TimeValue delay) { + return delay.duration() == 0; + } + + public void enqueue(RequestTask task) { + timeOfLastEnqueue.set(Instant.now(clock)); + + if (isShutdown()) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format( + "Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown", + task.getRequestManager().inferenceEntityId(), + id + ), + true + ); + + task.onRejection(rejected); + return; + } + + var addedToQueue = queue.offer(task); + + if (addedToQueue == false) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format( + "Failed to execute task for inference id [%s] because the request service [%s] queue is full", + task.getRequestManager().inferenceEntityId(), + id + ), + false + ); + + task.onRejection(rejected); + } else if (isShutdown()) { + notifyRequestsOfShutdown(); + } + } + + public synchronized void notifyRequestsOfShutdown() { + assert isShutdown() : "Requests should only be notified if the executor is shutting down"; + + try { + List notExecuted = new ArrayList<>(); + queue.drainTo(notExecuted); + + rejectTasks(notExecuted); + } catch (Exception e) { + logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id)); + } + } + + private void rejectTasks(List tasks) { + for (var task : tasks) { + rejectTaskForShutdown(task); + } + } + + private void rejectTaskForShutdown(RejectableTask task) { + try { + task.onRejection( + new EsRejectedExecutionException( + format( + "Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request", + id, + task.getRequestManager().inferenceEntityId() + ), + true + ) + ); + } catch (Exception e) { + logger.warn( + format( + "Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown", + task.getRequestManager().inferenceEntityId(), + id + ) + ); + } + } + + public int remainingCapacity() { + return queue.remainingCapacity(); + } + + public void close() { + requestExecutorServiceSettings.deregisterQueueCapacityCallback(id); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java index 86825035f2d0..616ef7a40068 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java @@ -10,9 +10,12 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; 
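RateLimitingEndpointHandler above pairs each grouping's queue with a token-bucket style RateLimiter (timeToReserve/reserve), while the executor loop polls every grouping, executes at most one task per pass, and sleeps for the smaller of the task poll frequency and the shortest wait any limiter reported. The self-contained sketch below mirrors that shape under simplified assumptions; TokenBucket, Group, and pollOnce are illustrative stand-ins, not the production classes.

[source,java]
----
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;

// Illustrative sketch: a token-bucket limiter per group plus a polling loop that
// executes at most one queued task per group and sleeps for the shortest wait seen.
public class RateLimitedPollingSketch {

    static final Duration NO_TASKS = Duration.ofDays(1); // sentinel: nothing to do for this group

    // Simplified token bucket: refills at requestsPerMinute / 60 tokens per second, caps at one token.
    static final class TokenBucket {
        private final double tokensPerSecond;
        private double tokens = 1.0; // allow an immediate first request
        private Instant lastRefill = Instant.now();

        TokenBucket(double requestsPerMinute) {
            this.tokensPerSecond = requestsPerMinute / 60.0;
        }

        // How long until one token is available; zero means a request may go out now.
        synchronized Duration timeToReserve() {
            refill();
            if (tokens >= 1.0) {
                return Duration.ZERO;
            }
            long millis = (long) Math.ceil((1.0 - tokens) / tokensPerSecond * 1000.0);
            return Duration.ofMillis(millis);
        }

        // Consume one token; callers are expected to check timeToReserve() first.
        synchronized void reserve() {
            refill();
            tokens -= 1.0;
        }

        private void refill() {
            Instant now = Instant.now();
            double elapsedSeconds = Duration.between(lastRefill, now).toMillis() / 1000.0;
            tokens = Math.min(1.0, tokens + elapsedSeconds * tokensPerSecond);
            lastRefill = now;
        }
    }

    // One queue and one bucket per rate-limit grouping.
    static final class Group {
        final Queue<Runnable> queue = new ArrayDeque<>();
        final TokenBucket bucket;

        Group(double requestsPerMinute) {
            this.bucket = new TokenBucket(requestsPerMinute);
        }

        // Run one task if a token is available; otherwise report how long to wait.
        Duration pollOnce() {
            Duration wait = bucket.timeToReserve();
            if (wait.isZero() == false) {
                return wait;
            }
            Runnable task = queue.poll();
            if (task == null) {
                return NO_TASKS;
            }
            bucket.reserve();
            task.run();
            return Duration.ZERO; // did work, so look again right away
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Group group = new Group(120); // e.g. 120 requests per minute
        group.queue.add(() -> System.out.println("sent request"));

        Duration pollFrequency = Duration.ofMillis(50);
        for (int i = 0; i < 3; i++) {
            Duration sleep = pollFrequency;
            for (Group g : List.of(group)) {
                Duration groupWait = g.pollOnce();
                if (groupWait.compareTo(sleep) < 0) {
                    sleep = groupWait;
                }
            }
            Thread.sleep(sleep.toMillis());
        }
    }
}
----

Returning a zero wait after doing work keeps the loop responsive, while an effectively infinite wait for an empty queue lets the poll frequency dominate; this is the same trade-off handleTasks() makes with EXECUTED_A_TASK and NO_TASKS_AVAILABLE.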
-import java.util.ArrayList; +import java.time.Duration; import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; public class RequestExecutorServiceSettings { @@ -29,37 +32,108 @@ public class RequestExecutorServiceSettings { Setting.Property.Dynamic ); + private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50); + /** + * Defines how often all the rate limit groups are polled for tasks. Setting this to very low number could result + * in a busy loop if there are no tasks available to handle. + */ + static final Setting TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.task_poll_frequency", + DEFAULT_TASK_POLL_FREQUENCY_TIME, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1); + /** + * Defines how often a thread will check for rate limit groups that are stale. + */ + static final Setting RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.rate_limit_group_cleanup_interval", + DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION = TimeValue.timeValueDays(10); + /** + * Defines the amount of time it takes to classify a rate limit group as stale. Once it is classified as stale, + * it can be removed when the cleanup thread executes. + */ + static final Setting RATE_LIMIT_GROUP_STALE_DURATION_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.rate_limit_group_stale_duration", + DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static List> getSettingsDefinitions() { - return List.of(TASK_QUEUE_CAPACITY_SETTING); + return List.of( + TASK_QUEUE_CAPACITY_SETTING, + TASK_POLL_FREQUENCY_SETTING, + RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING, + RATE_LIMIT_GROUP_STALE_DURATION_SETTING + ); } private volatile int queueCapacity; - private final List> queueCapacityCallbacks = new ArrayList>(); + private volatile TimeValue taskPollFrequency; + private volatile Duration rateLimitGroupStaleDuration; + private final ConcurrentMap> queueCapacityCallbacks = new ConcurrentHashMap<>(); public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) { queueCapacity = TASK_QUEUE_CAPACITY_SETTING.get(settings); + taskPollFrequency = TASK_POLL_FREQUENCY_SETTING.get(settings); + setRateLimitGroupStaleDuration(RATE_LIMIT_GROUP_STALE_DURATION_SETTING.get(settings)); addSettingsUpdateConsumers(clusterService); } private void addSettingsUpdateConsumers(ClusterService clusterService) { clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity); + clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(RATE_LIMIT_GROUP_STALE_DURATION_SETTING, this::setRateLimitGroupStaleDuration); } // default for testing void setQueueCapacity(int queueCapacity) { this.queueCapacity = queueCapacity; - for (var callback : queueCapacityCallbacks) { + for (var callback : queueCapacityCallbacks.values()) { callback.accept(queueCapacity); } } - void registerQueueCapacityCallback(Consumer 
onChangeCapacityCallback) { - queueCapacityCallbacks.add(onChangeCapacityCallback); + private void setTaskPollFrequency(TimeValue taskPollFrequency) { + this.taskPollFrequency = taskPollFrequency; + } + + private void setRateLimitGroupStaleDuration(TimeValue staleDuration) { + rateLimitGroupStaleDuration = toDuration(staleDuration); + } + + private static Duration toDuration(TimeValue timeValue) { + return Duration.of(timeValue.duration(), timeValue.timeUnit().toChronoUnit()); + } + + void registerQueueCapacityCallback(String id, Consumer onChangeCapacityCallback) { + queueCapacityCallbacks.put(id, onChangeCapacityCallback); + } + + void deregisterQueueCapacityCallback(String id) { + queueCapacityCallbacks.remove(id); } int getQueueCapacity() { return queueCapacity; } + + TimeValue getTaskPollFrequency() { + return taskPollFrequency; + } + + Duration getRateLimitGroupStaleDuration() { + return rateLimitGroupStaleDuration; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java index 7d3cca596f1d..79ef1b56ad23 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; @@ -21,14 +20,17 @@ import java.util.function.Supplier; * A contract for constructing a {@link Runnable} to handle sending an inference request to a 3rd party service. */ public interface RequestManager extends RateLimitable { - Runnable create( + void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ); + // TODO For batching we'll add 2 new method: prepare(query, input, ...) 
which will allow the individual + // managers to implement their own batching + // executePreparedRequest() which will execute all prepared requests aka sends the batch + String inferenceEntityId(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index 738592464232..7a5f48241228 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -111,7 +111,7 @@ class RequestTask implements RejectableTask { } @Override - public RequestManager getRequestCreator() { + public RequestManager getRequestManager() { return requestCreator; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java deleted file mode 100644 index 494c77964080..000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.http.client.protocol.HttpClientContext; -import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; - -import java.util.Objects; - -/** - * Handles executing a single inference request at a time. 
- */ -public class SingleRequestManager { - - protected RetryingHttpSender requestSender; - - public SingleRequestManager(RetryingHttpSender requestSender) { - this.requestSender = Objects.requireNonNull(requestSender); - } - - public void execute(InferenceRequest inferenceRequest, HttpClientContext context) { - if (isNoopRequest(inferenceRequest) || inferenceRequest.hasCompleted()) { - return; - } - - inferenceRequest.getRequestCreator() - .create( - inferenceRequest.getQuery(), - inferenceRequest.getInput(), - requestSender, - inferenceRequest.getRequestCompletedFunction(), - context, - inferenceRequest.getListener() - ) - .run(); - } - - private static boolean isNoopRequest(InferenceRequest inferenceRequest) { - return inferenceRequest.getRequestCreator() == null - || inferenceRequest.getInput() == null - || inferenceRequest.getListener() == null; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 24c0ab2cd893..1c64f505402d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService { public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { Objects.requireNonNull(factory); - sender = factory.createSender(name()); + sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index c488eac42240..f30773962854 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -56,7 +56,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completio public class AzureAiStudioService extends SenderService { - private static final String NAME = "azureaistudio"; + static final String NAME = "azureaistudio"; public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java index 10c57e19b640..03034ae70c2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java @@ -44,7 +44,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec ConfigurationParseContext context ) { String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, 
DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AzureAiStudioService.NAME, + context + ); AzureAiStudioEndpointType endpointType = extractRequiredEnum( map, ENDPOINT_TYPE_FIELD, @@ -118,13 +124,13 @@ public abstract class AzureAiStudioServiceSettings extends FilteredXContentObjec protected void addXContentFields(XContentBuilder builder, Params params) throws IOException { this.addExposedXContentFields(builder, params); - rateLimitSettings.toXContent(builder, params); } protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException { builder.field(TARGET_FIELD, this.target); builder.field(PROVIDER_FIELD, this.provider); builder.field(ENDPOINT_TYPE_FIELD, this.endpointType); + rateLimitSettings.toXContent(builder, params); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e0e48ab20a86..26bf6f1648d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -135,7 +135,15 @@ public class AzureOpenAiService extends SenderService { ); } case COMPLETION -> { - return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + return new AzureOpenAiCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); } default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java index 05cb66345354..c4146b2ba2d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; @@ -37,13 +38,14 @@ public class AzureOpenAiCompletionModel extends AzureOpenAiModel { String service, Map serviceSettings, Map taskSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings), + AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context), AzureOpenAiCompletionTaskSettings.fromMap(taskSettings), 
AzureOpenAiSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java index ba503b2bbdc4..92dc461d9008 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java @@ -17,7 +17,9 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -55,10 +57,10 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120); - public static AzureOpenAiCompletionServiceSettings fromMap(Map map) { + public static AzureOpenAiCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); - var settings = fromMap(map, validationException); + var settings = fromMap(map, validationException, context); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -69,12 +71,19 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap( Map map, - ValidationException validationException + ValidationException validationException, + ConfigurationParseContext context ) { String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AzureOpenAiService.NAME, + context + ); return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); } @@ -137,7 +146,6 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -148,6 +156,7 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject builder.field(RESOURCE_NAME, resourceName); builder.field(DEPLOYMENT_ID, deploymentId); builder.field(API_VERSION, 
apiVersion); + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java index 33bb0fdb07c5..1c426815a83c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java @@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -90,7 +91,13 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AzureOpenAiService.NAME, + context + ); Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); @@ -245,8 +252,6 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - - rateLimitSettings.toXContent(builder, params); builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); builder.endObject(); @@ -268,6 +273,7 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject if (similarity != null) { builder.field(SIMILARITY, similarity); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 11dbf673ab7b..4c673026d7ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -51,6 +51,11 @@ import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFie public class CohereService extends SenderService { public static final String NAME = "cohere"; + // TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to + // the Cohere*RequestManager via the CohereActionCreator class + // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are 
instantiated + // on every request + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); } @@ -131,7 +136,15 @@ public class CohereService extends SenderService { context ); case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); - case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + case COMPLETION -> new CohereCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index b23f6f188d8c..d477a8c5a5f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -58,7 +58,13 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CohereService.NAME, + context + ); String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -173,10 +179,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser } public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { - toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); - - return builder; + return toXContentFragmentOfExposedFields(builder, params); } @Override @@ -196,6 +199,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser if (modelId != null) { builder.field(MODEL_ID, modelId); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java index 761081d4d723..bec4f5a0b5c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor; 
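Each service settings class in this change now calls RateLimitSettings.of(map, default, validationException, serviceName, context) with its own default, for example 10,000 requests per minute for Cohere completion. That method's implementation is not part of this diff, so the sketch below only illustrates the general shape such parsing could take, assuming the rate_limit / requests_per_minute field names used in this PR's documentation; parseRateLimit and RateLimit are hypothetical names.

[source,java]
----
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch: read an optional nested "rate_limit": {"requests_per_minute": N}
// entry out of a mutable settings map, falling back to a per-service default.
public class RateLimitParsingSketch {

    record RateLimit(long requestsPerMinute) {}

    @SuppressWarnings("unchecked")
    static RateLimit parseRateLimit(Map<String, Object> serviceSettings, RateLimit serviceDefault) {
        Object rateLimit = serviceSettings.remove("rate_limit");
        if (rateLimit == null) {
            return serviceDefault;
        }
        if (rateLimit instanceof Map<?, ?> == false) {
            throw new IllegalArgumentException("[rate_limit] must be an object");
        }
        Object rpm = ((Map<String, Object>) rateLimit).get("requests_per_minute");
        if (rpm == null) {
            return serviceDefault;
        }
        if (rpm instanceof Number number && number.longValue() > 0) {
            return new RateLimit(number.longValue());
        }
        throw new IllegalArgumentException("[requests_per_minute] must be a positive integer");
    }

    public static void main(String[] args) {
        RateLimit cohereCompletionDefault = new RateLimit(10_000); // default named in the diff

        Map<String, Object> settings = new HashMap<>();
        settings.put("model_id", "command");
        settings.put("rate_limit", Map.of("requests_per_minute", 500));

        System.out.println(parseRateLimit(settings, cohereCompletionDefault).requestsPerMinute()); // 500
        System.out.println(parseRateLimit(new HashMap<>(), cohereCompletionDefault).requestsPerMinute()); // 10000
    }
}
----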
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -30,13 +31,14 @@ public class CohereCompletionModel extends CohereModel { String service, Map serviceSettings, Map taskSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( modelId, taskType, service, - CohereCompletionServiceSettings.fromMap(serviceSettings), + CohereCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java index 2a22f6333f1a..ba9e81b461f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java @@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -39,12 +41,18 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl // 10K requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); - public static CohereCompletionServiceSettings fromMap(Map map) { + public static CohereCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CohereService.NAME, + context + ); String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { @@ -94,7 +102,6 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -127,6 +134,7 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl if (modelId != null) { builder.field(MODEL_ID, modelId); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index f8720448b0f4..cfa856649514 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -108,7 +108,8 @@ public class GoogleAiStudioService extends SenderService { NAME, serviceSettings, taskSettings, - secretSettings + secretSettings, + context ); case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel( inferenceEntityId, @@ -116,7 +117,8 @@ public class GoogleAiStudioService extends SenderService { NAME, serviceSettings, taskSettings, - secretSettings + secretSettings, + context ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java index eafb0c372202..8fa2ac014871 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -37,13 +38,14 @@ public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel { String service, Map serviceSettings, Map taskSettings, - Map secrets + Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings), + GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java index f8f343be8eb4..7c0b812ee213 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java @@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import 
org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -40,11 +42,17 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); - public static GoogleAiStudioCompletionServiceSettings fromMap(Map map) { + public static GoogleAiStudioCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + GoogleAiStudioService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -82,7 +90,6 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -107,6 +114,7 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj @Override protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java index ad106797de51..af19e26f3e97 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -37,13 +38,14 @@ public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel { String service, Map serviceSettings, Map taskSettings, - Map secrets + Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - 
GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings), + GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java index 07d07dc533f0..7608f48d0638 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java @@ -18,7 +18,9 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -47,7 +49,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); - public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map map) { + public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -59,7 +61,13 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj ); SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + GoogleAiStudioService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -134,7 +142,6 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -174,6 +181,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj if (similarity != null) { builder.field(SIMILARITY, similarity); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index ebb6c10207f4..ef034816f762 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -62,7 +63,8 @@ public abstract class HuggingFaceBaseService extends SenderService { taskType, serviceSettingsMap, serviceSettingsMap, - TaskType.unsupportedTaskTypeErrorMsg(taskType, name()) + TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), + ConfigurationParseContext.REQUEST ); throwIfNotEmptyMap(config, name()); @@ -89,7 +91,8 @@ public abstract class HuggingFaceBaseService extends SenderService { taskType, serviceSettingsMap, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, name()) + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT ); } @@ -97,7 +100,14 @@ public abstract class HuggingFaceBaseService extends SenderService { public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - return createModel(inferenceEntityId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(inferenceEntityId, name())); + return createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT + ); } protected abstract HuggingFaceModel createModel( @@ -105,7 +115,8 @@ public abstract class HuggingFaceBaseService extends SenderService { TaskType taskType, Map serviceSettings, Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ); @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d8c383d2b4a6..c0438b3759a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import 
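The HuggingFaceBaseService change above is the pattern every service in this patch follows: the parse context is decided once at the entry points and then threaded down to the settings parsers. A condensed sketch of that split, with the argument shapes taken from the hunks above and the surrounding method bodies elided:

[source,java]
----
// Request-time parsing: a PUT inference request is validated strictly, so unknown
// keys inside "rate_limit" become validation errors.
var requestModel = createModel(
    inferenceEntityId,
    taskType,
    serviceSettingsMap,
    serviceSettingsMap,
    TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
    ConfigurationParseContext.REQUEST
);

// Reading a previously persisted configuration back: parsed leniently, so existing
// endpoints keep loading unchanged.
var persistedModel = createModel(
    inferenceEntityId,
    taskType,
    serviceSettingsMap,
    secretSettingsMap,
    parsePersistedConfigErrorMsg(inferenceEntityId, name()),
    ConfigurationParseContext.PERSISTENT
);
----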
org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; @@ -36,11 +37,19 @@ public class HuggingFaceService extends HuggingFaceBaseService { TaskType taskType, Map serviceSettings, @Nullable Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ) { return switch (taskType) { - case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings); - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings); + case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); + case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index af2c433663ac..fc31b1e518dd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -43,14 +44,20 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement // 3000 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); - public static HuggingFaceServiceSettings fromMap(Map map) { + public static HuggingFaceServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var uri = extractUri(map, URL, validationException); SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -119,7 +126,6 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; } @@ -136,6 +142,7 @@ public class 
HuggingFaceServiceSettings extends FilteredXContentObject implement if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java index 9010571ea2e5..8132089d8dc9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -24,13 +25,14 @@ public class HuggingFaceElserModel extends HuggingFaceModel { TaskType taskType, String service, Map serviceSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - HuggingFaceElserServiceSettings.fromMap(serviceSettings), + HuggingFaceElserServiceSettings.fromMap(serviceSettings, context), DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 2587b2737e16..d3099e96ee7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; @@ -38,10 +39,11 @@ public class HuggingFaceElserService extends HuggingFaceBaseService { TaskType taskType, Map serviceSettings, @Nullable Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ) { return switch (taskType) { - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings); + case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java index 1f337de450ef..8b4bd61649de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java @@ -15,7 +15,9 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -40,10 +42,16 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject // 3000 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); - public static HuggingFaceElserServiceSettings fromMap(Map map) { + public static HuggingFaceElserServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var uri = extractUri(map, URL, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -93,7 +101,6 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -103,6 +110,7 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { builder.field(URL, uri.toString()); builder.field(MAX_INPUT_TOKENS, ELSER_TOKEN_LIMIT); + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java index 1cee26558b49..fedd6380d035 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import 
org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -25,13 +26,14 @@ public class HuggingFaceEmbeddingsModel extends HuggingFaceModel { TaskType taskType, String service, Map serviceSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - HuggingFaceServiceSettings.fromMap(serviceSettings), + HuggingFaceServiceSettings.fromMap(serviceSettings, context), DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index d2ea8ccbd18b..62d06a4e0029 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -59,7 +60,13 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp ModelConfigurations.SERVICE_SETTINGS, validationException ); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + MistralService.NAME, + context + ); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); if (validationException.validationErrors().isEmpty() == false) { @@ -141,7 +148,6 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); this.toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; } @@ -159,6 +165,7 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp if (this.maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, this.maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 84dfac890367..04b6ae94d6b5 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -138,7 +138,8 @@ public class OpenAiService extends SenderService { NAME, serviceSettings, taskSettings, - secretSettings + secretSettings, + context ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index b1b670c0911f..7ca93684bc68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -35,13 +36,14 @@ public class OpenAiChatCompletionModel extends OpenAiModel { String service, Map serviceSettings, Map taskSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - OpenAiChatCompletionServiceSettings.fromMap(serviceSettings), + OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context), OpenAiChatCompletionTaskSettings.fromMap(taskSettings), DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java index 5105bb59e048..04f77da1b146 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java @@ -16,7 +16,9 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -47,7 +49,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject // 500 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); - public static 
OpenAiChatCompletionServiceSettings fromMap(Map map) { + public static OpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -58,7 +60,13 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenAiService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -142,7 +150,6 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -163,6 +170,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index fc479009d333..080251bf1ba3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -66,7 +67,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl // passed at that time and never throw. 
ValidationException validationException = new ValidationException(); - var commonFields = fromMap(map, validationException); + var commonFields = fromMap(map, validationException, ConfigurationParseContext.PERSISTENT); Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); if (dimensionsSetByUser == null) { @@ -80,7 +81,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl private static OpenAiEmbeddingsServiceSettings fromRequestMap(Map map) { ValidationException validationException = new ValidationException(); - var commonFields = fromMap(map, validationException); + var commonFields = fromMap(map, validationException, ConfigurationParseContext.REQUEST); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -89,7 +90,11 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl return new OpenAiEmbeddingsServiceSettings(commonFields, commonFields.dimensions != null); } - private static CommonFields fromMap(Map map, ValidationException validationException) { + private static CommonFields fromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -98,7 +103,13 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl Integer dims = removeAsType(map, DIMENSIONS, Integer.class); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenAiService.NAME, + context + ); return new CommonFields(modelId, uri, organizationId, similarity, maxInputTokens, dims, rateLimitSettings); } @@ -258,7 +269,6 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); if (dimensionsSetByUser != null) { builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); @@ -286,6 +296,7 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java index cfc375a525dd..f593ca4e0c60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; +import 
org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; import java.util.Map; @@ -21,19 +22,29 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; public class RateLimitSettings implements Writeable, ToXContentFragment { - public static final String FIELD_NAME = "rate_limit"; public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute"; private final long requestsPerTimeUnit; private final TimeUnit timeUnit; - public static RateLimitSettings of(Map map, RateLimitSettings defaultValue, ValidationException validationException) { + public static RateLimitSettings of( + Map map, + RateLimitSettings defaultValue, + ValidationException validationException, + String serviceName, + ConfigurationParseContext context + ) { Map settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME); var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException); + if (ConfigurationParseContext.isRequestContext(context)) { + throwIfNotEmptyMap(settings, serviceName); + } + return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 88d408d309a7..8792234102a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -92,7 +93,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase { TruncatorTests.createTruncator() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson)); @@ -141,7 +142,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase { TruncatorTests.createTruncator() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); webServer.enqueue(new 
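The `RateLimitSettings.of` change above ties the two threads together: it now takes the service name (used in error messages) and the parse context, and it only rejects unknown `rate_limit` keys when parsing a request. A minimal sketch of how the service settings parsers in this patch call it; the default of 360 requests per minute is the Google AI Studio value from the earlier hunks:

[source,java]
----
static RateLimitSettings parseRateLimit(Map<String, Object> serviceSettingsMap, ConfigurationParseContext context) {
    var validationException = new ValidationException();

    var rateLimitSettings = RateLimitSettings.of(
        serviceSettingsMap,
        new RateLimitSettings(360),       // per-service default, here 360 requests per minute
        validationException,
        GoogleAiStudioService.NAME,       // reported when an unknown rate_limit field is rejected
        context                           // strict only for ConfigurationParseContext.REQUEST
    );

    if (validationException.validationErrors().isEmpty() == false) {
        throw validationException;
    }
    return rateLimitSettings;
}
----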
MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 0a2a00143b20..72124a622125 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; @@ -82,7 +83,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -132,7 +133,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -183,7 +184,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -237,7 +238,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // note - there is no complete documentation on Azure's error messages @@ -313,7 +314,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = 
senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // note - there is no complete documentation on Azure's error messages @@ -389,7 +390,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -440,7 +441,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -498,7 +499,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -554,7 +555,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase { // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // "choices" missing diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java index 96127841c17a..7d5261640240 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; import static org.hamcrest.Matchers.hasSize; @@ -77,7 +78,7 @@ public class AzureOpenAiCompletionActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender 
= createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java index 89cc84732179..4cc7b7c0d9cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; @@ -81,7 +82,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 9b0371ad51f8..9ec34e7d8e5c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -73,7 +74,7 @@ public class CohereActionCreatorTests extends ESTestCase { public void testCreate_CohereEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -154,7 +155,7 @@ public class CohereActionCreatorTests extends ESTestCase { public void 
testCreate_CohereCompletionModel_WithModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -214,7 +215,7 @@ public class CohereActionCreatorTests extends ESTestCase { public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index 12c3d132d124..0a604980f6c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -77,7 +77,7 @@ public class CohereCompletionActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -138,7 +138,7 @@ public class CohereCompletionActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -290,7 +290,7 @@ public class CohereCompletionActionTests extends ESTestCase { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index dbc97fa2e13d..9cf6de27b93b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -81,7 +81,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -162,7 +162,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java index 09ef5351eb1f..9dd465e0276f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -74,7 +74,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -206,7 +206,7 @@ public class GoogleAiStudioCompletionActionTests extends ESTestCase { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java index a55b3c5f5030..7e98b9b31f6e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java @@ -79,7 +79,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase { var input = "input"; var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); String responseJson = """ diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index fceea8810f6c..b3ec565b3146 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.contains; @@ -75,7 +76,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -131,7 +132,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -187,7 +188,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -239,7 +240,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]] @@ -292,7 +293,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJsonContentTooLarge = """ @@ -357,7 +358,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index 496238eaad0e..b6d7eb673b7f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -74,7 +75,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testCreate_OpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -127,7 +128,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -179,7 +180,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -238,7 +239,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -292,7 +293,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testCreate_OpenAiChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson 
= """ @@ -355,7 +356,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -417,7 +418,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -486,7 +487,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -552,7 +553,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); var contentTooLargeErrorMessage = @@ -635,7 +636,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); var contentTooLargeErrorMessage = @@ -718,7 +719,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index 914ff12db259..42b062667f77 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static 
org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -80,7 +81,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -234,7 +235,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 15b7417912ef..03c0b4d146b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -79,7 +79,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java new file mode 100644 index 000000000000..03838896b879 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.mockito.Mockito.mock;
+
+public class BaseRequestManagerTests extends ESTestCase {
+    public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() {
+        int val1 = 1;
+        int val2 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), is(manager2.rateLimitGrouping()));
+    }
+
+    public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() {
+        int val1 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
+    }
+
+    public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() {
+        int val1 = 1;
+
+        var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) {
+            @Override
+            public void execute(
+                String query,
+                List<String> input,
+                RequestSender requestSender,
+                Supplier<Boolean> hasRequestCompletedFunction,
+                ActionListener<InferenceServiceResults> listener
+            ) {
+
+            }
+        };
+
+        assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping()));
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
index 368745b31088..2b8b5f178b3d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
+++
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -79,7 +79,7 @@ public class HttpRequestSenderTests extends ESTestCase { public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { var senderFactory = createSenderFactory(clientManager, threadRef); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -135,11 +135,11 @@ public class HttpRequestSenderTests extends ESTestCase { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( AssertionError.class, - () -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) + () -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) ); assertThat(thrownException.getMessage(), is("call start() before sending a request")); } @@ -155,17 +155,12 @@ public class HttpRequestSenderTests extends ESTestCase { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { assertThat(sender, instanceOf(HttpRequestSender.class)); sender.start(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -186,16 +181,11 @@ public class HttpRequestSenderTests extends ESTestCase { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -220,6 +210,7 @@ public class HttpRequestSenderTests extends ESTestCase { when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService); when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class)); + when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class)); return new HttpRequestSender.Factory( ServiceComponentsTests.createWithEmptySettings(mockThreadPool), @@ -248,7 +239,7 @@ public class HttpRequestSenderTests extends ESTestCase { ); } - public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) { - return factory.createSender(serviceName); + public static Sender createSender(HttpRequestSender.Factory factory) { + return factory.createSender(); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java index c0c0bdd49f61..489b502c0411 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; import static org.elasticsearch.xpack.inference.Utils.mockClusterService; @@ -18,12 +19,23 @@ public class RequestExecutorServiceSettingsTests { } public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) { + return createRequestExecutorServiceSettings(queueCapacity, null); + } + + public static RequestExecutorServiceSettings createRequestExecutorServiceSettings( + @Nullable Integer queueCapacity, + @Nullable TimeValue staleDuration + ) { var settingsBuilder = Settings.builder(); if (queueCapacity != null) { settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity); } + if (staleDuration != null) { + settingsBuilder.put(RequestExecutorServiceSettings.RATE_LIMIT_GROUP_STALE_DURATION_SETTING.getKey(), staleDuration); + } + return createRequestExecutorServiceSettings(settingsBuilder.build()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index ff88ba221d98..9a45e1000764 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -18,13 +18,19 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.RateLimiter; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; @@ -42,10 +48,13 @@ import static org.elasticsearch.xpack.inference.external.http.sender.RequestExec import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; public class RequestExecutorServiceTests extends ESTestCase { @@ -70,7 +79,7 @@ public class RequestExecutorServiceTests extends ESTestCase { public void testQueueSize_IsOne() { var service = createRequestExecutorServiceWithMocks(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); } @@ -92,7 +101,20 @@ public class RequestExecutorServiceTests extends ESTestCase { assertTrue(service.isTerminated()); } - public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { + public void testCallingStartTwice_ThrowsAssertionException() throws InterruptedException { + var latch = new CountDownLatch(1); + var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class)); + + service.shutdown(); + service.start(); + latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + + assertTrue(service.isTerminated()); + var exception = expectThrows(AssertionError.class, service::start); + assertThat(exception.getMessage(), is("start() can only be called once")); + } + + public void testIsTerminated_AfterStopFromSeparateThread() { var waitToShutdown = new CountDownLatch(1); var waitToReturnFromSend = new CountDownLatch(1); @@ -127,41 +149,48 @@ public class RequestExecutorServiceTests extends ESTestCase { assertTrue(service.isTerminated()); } - public void testSend_AfterShutdown_Throws() { + public void testExecute_AfterShutdown_Throws() { var service = createRequestExecutorServiceWithMocks(); service.shutdown(); + var requestManager = RequestManagerTests.createMock("id"); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to enqueue task because the http executor service [test_service] has already shutdown") + is( + Strings.format( + "Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertTrue(thrownException.isExecutorShutdown()); } - public void testSend_Throws_WhenQueueIsFull() { - var service = new RequestExecutorService( - "test_service", - threadPool, - null, - createRequestExecutorServiceSettings(1), - new SingleRequestManager(mock(RetryingHttpSender.class)) - ); + public void testExecute_Throws_WhenQueueIsFull() { + var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class)); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + + var requestManager = RequestManagerTests.createMock("id"); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + 
service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") + is( + Strings.format( + "Failed to execute task for inference id [id] because the request service [%s] queue is full", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertFalse(thrownException.isExecutorShutdown()); } @@ -203,16 +232,11 @@ public class RequestExecutorServiceTests extends ESTestCase { assertTrue(service.isShutdown()); } - public void testSend_CallsOnFailure_WhenRequestTimesOut() { + public void testExecute_CallsOnFailure_WhenRequestTimesOut() { var service = createRequestExecutorServiceWithMocks(); var listener = new PlainActionFuture(); - service.execute( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -222,7 +246,7 @@ public class RequestExecutorServiceTests extends ESTestCase { ); } - public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException { + public void testExecute_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException { var headerKey = "not empty"; var headerValue = "value"; @@ -270,7 +294,7 @@ public class RequestExecutorServiceTests extends ESTestCase { } }; - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); @@ -280,11 +304,12 @@ public class RequestExecutorServiceTests extends ESTestCase { finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); } - public void testSend_NotifiesTasksOfShutdown() { + public void testExecute_NotifiesTasksOfShutdown() { var service = createRequestExecutorServiceWithMocks(); + var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id"); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); service.shutdown(); service.start(); @@ -293,47 +318,62 @@ public class RequestExecutorServiceTests extends ESTestCase { assertThat( thrownException.getMessage(), - is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + is( + Strings.format( + "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertTrue(thrownException.isExecutorShutdown()); assertTrue(service.isTerminated()); } - public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { + public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { @SuppressWarnings("unchecked") BlockingQueue queue = 
mock(LinkedBlockingQueue.class); + var requestSender = mock(RetryingHttpSender.class); + var service = new RequestExecutorService( - getTestName(), threadPool, mockQueueCreator(queue), null, createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(mock(RetryingHttpSender.class)) + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); - when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { + PlainActionFuture listener = new PlainActionFuture<>(); + var requestManager = RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { service.shutdown(); return null; }); service.start(); assertTrue(service.isTerminated()); - verify(queue, times(2)).take(); } - public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception { + public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception { @SuppressWarnings("unchecked") BlockingQueue queue = mock(LinkedBlockingQueue.class); - when(queue.take()).thenThrow(new InterruptedException("failed")); + var sleeper = mock(RequestExecutorService.Sleeper.class); + doThrow(new InterruptedException("failed")).when(sleeper).sleep(any()); var service = new RequestExecutorService( - getTestName(), threadPool, mockQueueCreator(queue), null, createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(mock(RetryingHttpSender.class)) + mock(RetryingHttpSender.class), + Clock.systemUTC(), + sleeper, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); Future executorTermination = threadPool.generic().submit(() -> { @@ -347,66 +387,30 @@ public class RequestExecutorServiceTests extends ESTestCase { executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - verify(queue, times(1)).take(); - } - - public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception { - var mockTask = mock(RejectableTask.class); - @SuppressWarnings("unchecked") - BlockingQueue queue = mock(LinkedBlockingQueue.class); - - var service = new RequestExecutorService( - "test_service", - threadPool, - mockQueueCreator(queue), - null, - createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(mock(RetryingHttpSender.class)) - ); - - doAnswer(invocation -> { - service.shutdown(); - return mockTask; - }).doReturn(new NoopTask()).when(queue).take(); - - service.start(); - - assertTrue(service.isTerminated()); - verify(queue, times(1)).take(); - - ArgumentCaptor argument = ArgumentCaptor.forClass(Exception.class); - verify(mockTask, times(1)).onRejection(argument.capture()); - assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class)); - assertThat( - argument.getValue().getMessage(), - is("Failed to send request, queue service [test_service] has shutdown prior to executing request") - ); - - var rejectionException = (EsRejectedExecutionException) argument.getValue(); - assertTrue(rejectionException.isExecutorShutdown()); } public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException { var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(1); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new 
SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + var requestManager = RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") + is( + Strings.format( + "Failed to execute task for inference id [id] because the request service [%s] queue is full", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); settings.setQueueCapacity(2); @@ -426,7 +430,7 @@ public class RequestExecutorServiceTests extends ESTestCase { executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(2)); + assertThat(service.remainingQueueCapacity(requestManager), is(2)); } public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException, @@ -434,23 +438,24 @@ public class RequestExecutorServiceTests extends ESTestCase { var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(3); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), + RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>() ); service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), + RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>() ); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + var requestManager = RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); assertThat(service.queueSize(), is(3)); settings.setQueueCapacity(1); @@ -470,7 +475,7 @@ public class RequestExecutorServiceTests extends ESTestCase { executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(1)); + assertThat(service.remainingQueueCapacity(requestManager), is(1)); assertThat(service.queueSize(), is(0)); var thrownException = expectThrows( @@ -479,7 +484,12 @@ public class RequestExecutorServiceTests extends ESTestCase { ); assertThat( thrownException.getMessage(), - is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + is( + 
Strings.format( + "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertTrue(thrownException.isExecutorShutdown()); } @@ -489,23 +499,24 @@ public class RequestExecutorServiceTests extends ESTestCase { var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(1); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); + var requestManager = RequestManagerTests.createMock(requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") + is( + Strings.format( + "Failed to execute task for inference id [id] because the request service [%s] queue is full", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); settings.setQueueCapacity(0); @@ -525,7 +536,133 @@ public class RequestExecutorServiceTests extends ESTestCase { executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE)); + assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE)); + } + + public void testDoesNotExecuteTask_WhenCannotReserveTokens() { + var mockRateLimiter = mock(RateLimiter.class); + RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter; + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(1); + var service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + rateLimiterCreator + ); + var requestManager = RequestManagerTests.createMock(requestSender); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + doAnswer(invocation -> { + service.shutdown(); + return TimeValue.timeValueDays(1); + }).when(mockRateLimiter).timeToReserve(anyInt()); + + service.start(); + + verifyNoInteractions(requestSender); + } + + public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() { + var mockRateLimiter = mock(RateLimiter.class); + when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0)); + + RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter; + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(1); + var 
service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + rateLimiterCreator + ); + var requestManager = RequestManagerTests.createMock(requestSender); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0)); + + doAnswer(invocation -> { + service.shutdown(); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + service.start(); + + verify(requestSender, times(1)).send(any(), any(), any(), any(), any(), any()); + } + + public void testRemovesRateLimitGroup_AfterStaleDuration() { + var now = Instant.now(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1)); + var service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + clock, + RequestExecutorService.DEFAULT_SLEEPER, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR + ); + var requestManager = RequestManagerTests.createMock(requestSender, "id1"); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + assertThat(service.numberOfRateLimitGroups(), is(1)); + // the time is moved to after the stale duration, so now we should remove this grouping + when(clock.instant()).thenReturn(now.plus(Duration.ofDays(2))); + service.removeStaleGroupings(); + assertThat(service.numberOfRateLimitGroups(), is(0)); + + var requestManager2 = RequestManagerTests.createMock(requestSender, "id2"); + service.execute(requestManager2, new DocumentsOnlyInput(List.of()), null, listener); + + assertThat(service.numberOfRateLimitGroups(), is(1)); + } + + public void testStartsCleanupThread() { + var mockThreadPool = mock(ThreadPool.class); + + when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class)); + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1)); + var service = new RequestExecutorService( + mockThreadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR + ); + + service.shutdown(); + service.start(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TimeValue.class); + verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), argument.capture(), any()); + assertThat(argument.getValue(), is(TimeValue.timeValueDays(1))); } private Future submitShutdownRequest( @@ -552,12 +689,6 @@ public class RequestExecutorServiceTests extends ESTestCase { } private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) { - return new RequestExecutorService( - "test_service", - threadPool, - startupLatch, - createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(requestSender) - ); + return new RequestExecutorService(threadPool, 
startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java similarity index 56% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java index 31297ed432ef..291de740aca3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.RequestTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; @@ -21,34 +22,47 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ExecutableRequestCreatorTests { +public class RequestManagerTests { public static RequestManager createMock() { - var mockCreator = mock(RequestManager.class); - when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {}); + return createMock(mock(RequestSender.class)); + } - return mockCreator; + public static RequestManager createMock(String inferenceEntityId) { + return createMock(mock(RequestSender.class), inferenceEntityId); } public static RequestManager createMock(RequestSender requestSender) { - return createMock(requestSender, "id"); + return createMock(requestSender, "id", new RateLimitSettings(1)); } - public static RequestManager createMock(RequestSender requestSender, String modelId) { - var mockCreator = mock(RequestManager.class); + public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) { + return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1)); + } + + public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) { + var mockManager = mock(RequestManager.class); doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[5]; - return (Runnable) () -> requestSender.send( + ActionListener listener = (ActionListener) invocation.getArguments()[4]; + requestSender.send( mock(Logger.class), - RequestTests.mockRequest(modelId), + RequestTests.mockRequest(inferenceEntityId), HttpClientContext.create(), () -> false, mock(ResponseHandler.class), listener ); - }).when(mockCreator).create(any(), anyList(), any(), any(), any(), any()); - return mockCreator; + return Void.TYPE; + }).when(mockManager).execute(any(), anyList(), any(), any(), any()); + + // just return something consistent so the hashing works + when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId); + + 
when(mockManager.rateLimitSettings()).thenReturn(settings); + when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId); + + return mockManager; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java deleted file mode 100644 index 55965bc2354d..000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.http.client.protocol.HttpClientContext; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -public class SingleRequestManagerTests extends ESTestCase { - public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() { - var requestCreator = mock(RequestManager.class); - var request = mock(InferenceRequest.class); - when(request.getRequestCreator()).thenReturn(requestCreator); - - new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), HttpClientContext.create()); - verifyNoInteractions(requestCreator); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index ee3403492c42..974b31e73b49 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,7 +58,7 @@ public class SenderServiceTests extends ESTestCase { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -67,7 +66,7 @@ public class SenderServiceTests extends ESTestCase { listener.actionGet(TIMEOUT); verify(sender, times(1)).start(); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); } verify(sender, times(1)).close(); @@ -79,7 +78,7 @@ public class SenderServiceTests extends ESTestCase { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - 
when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -89,7 +88,7 @@ public class SenderServiceTests extends ESTestCase { service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(2)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 5869366ac2e2..cacbba82446f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -76,7 +76,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -819,7 +818,7 @@ public class AzureAiStudioServiceTests extends ESTestCase { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -841,7 +840,7 @@ public class AzureAiStudioServiceTests extends ESTestCase { is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java index 79d6e384d769..d46a5f190017 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java @@ -112,7 +112,8 @@ public class AzureAiStudioChatCompletionServiceSettingsTests extends ESTestCase String xContentResult = Strings.toString(builder); assertThat(xContentResult, CoreMatchers.is(""" - {"target":"target_value","provider":"openai","endpoint_type":"token"}""")); + {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """ + "rate_limit":{"requests_per_minute":3}}""")); } public static HashMap createRequestSettingsMap(String target, String provider, String endpointType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java index 283bfa1490df..a592dd6e1f95 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java @@ -295,7 +295,7 @@ public class AzureAiStudioEmbeddingsServiceSettingsTests extends ESTestCase { assertThat(xContentResult, CoreMatchers.is(""" {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """ - "dimensions":1024,"max_input_tokens":512}""")); + "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}""")); } public static HashMap createRequestSettingsMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 9fe8b472b22a..bb3407056d57 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -594,7 +593,7 @@ public class AzureOpenAiServiceTests extends ESTestCase { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -616,7 +615,7 @@ public class AzureOpenAiServiceTests extends ESTestCase { is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java index 46e514c8b16c..797cad8f300a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import 
org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; import java.io.IOException; @@ -46,7 +47,8 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria AzureOpenAiServiceFields.API_VERSION, apiVersion ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); @@ -63,18 +65,6 @@ public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSeria {"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}""")); } - public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { - var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = entity.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}""")); - } - @Override protected Writeable.Reader instanceReader() { return AzureOpenAiCompletionServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java index f4c6f9b2a4f0..cbb9eea22380 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java @@ -389,7 +389,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}""")); } - public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_RateLimit() throws IOException { + public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser() throws IOException { var entity = new AzureOpenAiEmbeddingsServiceSettings( "resource", "deployment", @@ -408,7 +408,7 @@ public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSeria assertThat(xContentResult, is(""" {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """ - "dimensions":1024,"max_input_tokens":512}""")); + "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":1}}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f06fee4b0b9c..902d96be2973 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static 
org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -613,7 +612,7 @@ public class CohereServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);

-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -635,7 +634,7 @@ public class CohereServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java
index aac04e301ece..b9fc7ee7b995 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java
@@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.util.HashMap;
@@ -28,7 +29,8 @@ public class CohereCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of()),
             new HashMap<>(Map.of("model", "overridden model")),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java
index f4cab3c2b0f1..ed8bc90d3214 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java
@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
@@ -34,7 +35,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
         var model = "model";

         var serviceSettings = CohereCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model))
+            new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null)));
@@ -55,7 +57,8 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute))));
@@ -72,18 +75,6 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializin
             {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WithFilteredObject_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url","model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return CohereCompletionServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
index 6f8fe6344b57..73ebd6c6c050 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java
@@ -331,21 +331,6 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereEmbeddingsServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)),
-            CohereEmbeddingType.INT8
-        );
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """
-            "embedding_type":"byte"}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return CohereEmbeddingsServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java
index 4943ddf74fda..1ce5a9fb1283 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java
@@ -51,20 +51,6 @@ public class CohereRerankServiceSettingsTests extends AbstractWireSerializingTes
             "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new CohereRerankServiceSettings(
-            new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3))
-        );
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-        // TODO we probably shouldn't allow configuring these fields for reranking
-        assertThat(xContentResult, is("""
-            {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id"}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return CohereRerankServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index 32e912ff8529..110276e63d07 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.endsWith;
 import static org.hamcrest.Matchers.hasSize;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -494,7 +493,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);

-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -516,7 +515,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java
index 025317fbe025..f4c13db78c4b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java
@@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

 import java.net.URISyntaxException;
@@ -28,7 +29,8 @@ public class GoogleAiStudioCompletionModelTests extends ESTestCase {
             "service",
             new HashMap<>(Map.of("model_id", "model")),
             new HashMap<>(Map.of()),
-            null
+            null,
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java
index 46e6e60af493..6652af26e09e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java
@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;

@@ -31,7 +32,10 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
     public void testFromMap_Request_CreatesSettingsCorrectly() {
         var model = "some model";

-        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)));
+        var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)),
+            ConfigurationParseContext.PERSISTENT
+        );

         assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null)));
     }
@@ -47,18 +51,6 @@ public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSe
             {"model_id":"model","rate_limit":{"requests_per_minute":360}}"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioCompletionServiceSettings("model", null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"model_id":"model"}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return GoogleAiStudioCompletionServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java
index b5fbd28b476b..cc195333adfd 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java
@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;

@@ -55,7 +56,8 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
                     ServiceFields.SIMILARITY,
                     similarity.toString()
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null)));
@@ -80,23 +82,6 @@ public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSe
             }"""));
     }

-    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null);
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
-            {
-                "model_id":"model",
-                "max_input_tokens": 1024,
-                "dimensions": 8,
-                "similarity": "dot_product"
-            }"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return GoogleAiStudioEmbeddingsServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java
index 398b21312a03..fd7e1b48b7e0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java
@@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.junit.After;
 import org.junit.Before;
@@ -33,7 +34,6 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.hamcrest.CoreMatchers.is;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +59,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);

-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -81,7 +81,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

@@ -111,7 +111,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             TaskType taskType,
             Map serviceSettings,
             Map secretSettings,
-            String failureMessage
+            String failureMessage,
+            ConfigurationParseContext context
         ) {
             return null;
         }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java
index 91b91593adee..04e9697b0887 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java
@@ -15,6 +15,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -57,7 +58,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var dims = 384;
         var maxInputTokens = 128;
         {
-            var serviceSettings = HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)));
+            var serviceSettings = HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            );
             assertThat(serviceSettings, is(new HuggingFaceServiceSettings(url)));
         }
         {
@@ -73,7 +77,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                         ServiceFields.MAX_INPUT_TOKENS,
                         maxInputTokens
                     )
-                )
+                ),
+                ConfigurationParseContext.PERSISTENT
             );
             assertThat(
                 serviceSettings,
@@ -95,7 +100,8 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
                         RateLimitSettings.FIELD_NAME,
                         new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
                     )
-                )
+                ),
+                ConfigurationParseContext.PERSISTENT
             );
             assertThat(
                 serviceSettings,
@@ -105,7 +111,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     }

     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );

         assertThat(
             thrownException.getMessage(),
@@ -118,7 +127,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT)
         );

         assertThat(
@@ -136,7 +145,7 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)))
+            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT)
         );

         assertThat(
@@ -152,7 +161,10 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
         var similarity = "by_size";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)))
+            () -> HuggingFaceServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(
@@ -175,18 +187,6 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
             {"url":"url","rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceServiceSettings(ServiceUtils.createUri("url"), null, null, null, new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url"}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return HuggingFaceServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java
index 57f9c59b65e1..2a44429687fb 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java
@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -32,7 +33,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     public void testFromMap() {
         var url = "https://www.abc.com";
-        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)));
+        var serviceSettings = HuggingFaceElserServiceSettings.fromMap(
+            new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+            ConfigurationParseContext.PERSISTENT
+        );

         assertThat(new HuggingFaceElserServiceSettings(url), is(serviceSettings));
     }
@@ -40,7 +44,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(
@@ -55,7 +62,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
     }

     public void testFromMap_MissingUrl_ThrowsError() {
-        var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>()));
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT)
+        );

         assertThat(
             thrownException.getMessage(),
@@ -72,7 +82,10 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
         var url = "https://www.abc^.com";
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)))
+            () -> HuggingFaceElserServiceSettings.fromMap(
+                new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(
@@ -98,18 +111,6 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin
             {"url":"url","max_input_tokens":512,"rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new HuggingFaceElserServiceSettings(ServiceUtils.createUri("url"), new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"url":"url","max_input_tokens":512}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return HuggingFaceElserServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 3ead273e7811..624b24e61134 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -393,7 +392,7 @@ public class MistralServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);

-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -415,7 +414,7 @@ public class MistralServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
        );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java
index 13f43a5f31ad..076986acdcee 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java
@@ -98,18 +98,6 @@ public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
             "rate_limit":{"requests_per_minute":3}}"""));
     }

-    public void testToFilteredXContent_WritesFilteredValues() throws IOException {
-        var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = entity.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, CoreMatchers.is("""
-            {"model":"model_name","dimensions":1024,"max_input_tokens":512}"""));
-    }
-
     public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException {
         var outputBuffer = new BytesStreamOutput();
         var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index cbac29c45277..41995235565d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -675,7 +674,7 @@ public class OpenAiServiceTests extends ESTestCase {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);

-        when(factory.createSender(anyString())).thenReturn(sender);
+        when(factory.createSender()).thenReturn(sender);

         var mockModel = getInvalidModel("model_id", "service_name");

@@ -697,7 +696,7 @@ public class OpenAiServiceTests extends ESTestCase {
             is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
         );

-        verify(factory, times(1)).createSender(anyString());
+        verify(factory, times(1)).createSender();
         verify(sender, times(1)).start();
     }

diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java
index 186ca8942641..051a9bc6d9be 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java
@@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
@@ -48,7 +49,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(
@@ -77,7 +79,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     RateLimitSettings.FIELD_NAME,
                     new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit))
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertThat(
@@ -101,7 +104,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
                     ServiceFields.MAX_INPUT_TOKENS,
                     maxInputTokens
                 )
-            )
+            ),
+            ConfigurationParseContext.PERSISTENT
         );

         assertNull(serviceSettings.uri());
@@ -113,7 +117,10 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
     public void testFromMap_EmptyUrl_ThrowsError() {
         var thrownException = expectThrows(
             ValidationException.class,
-            () -> OpenAiChatCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")))
+            () -> OpenAiChatCompletionServiceSettings.fromMap(
+                new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
+            )
         );

         assertThat(
@@ -132,7 +139,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var maxInputTokens = 8192;

         var serviceSettings = OpenAiChatCompletionServiceSettings.fromMap(
-            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens))
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)),
+            ConfigurationParseContext.PERSISTENT
         );

         assertNull(serviceSettings.uri());
@@ -144,7 +152,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );

@@ -164,7 +173,8 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
         var thrownException = expectThrows(
             ValidationException.class,
             () -> OpenAiChatCompletionServiceSettings.fromMap(
-                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model"))
+                new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")),
+                ConfigurationParseContext.PERSISTENT
             )
         );

@@ -213,19 +223,6 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial
             {"model_id":"model","rate_limit":{"requests_per_minute":500}}"""));
     }

-    public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException {
-        var serviceSettings = new OpenAiChatCompletionServiceSettings("model", "url", "org", 1024, new RateLimitSettings(2));
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        var filteredXContent = serviceSettings.getFilteredXContentObject();
-        filteredXContent.toXContent(builder, null);
-        String xContentResult = org.elasticsearch.common.Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {"model_id":"model","url":"url","organization_id":"org",""" + """
-            "max_input_tokens":1024}"""));
-    }
-
     @Override
     protected Writeable.Reader instanceReader() {
         return OpenAiChatCompletionServiceSettings::new;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java
index 438f895fe48a..cc0004a2d678 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java
@@ -406,7 +406,7 @@ public class OpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         assertThat(xContentResult, is("""
             {"model_id":"model","url":"url","organization_id":"org","similarity":"dot_product",""" + """
-            "dimensions":1,"max_input_tokens":2}"""));
+            "dimensions":1,"max_input_tokens":2,"rate_limit":{"requests_per_minute":3000}}"""));
     }

     public void testToFilteredXContent_WritesAllValues_WithSpecifiedRateLimit() throws IOException {
@@ -428,7 +428,7 @@ public class OpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         assertThat(xContentResult, is("""
             {"model_id":"model","url":"url","organization_id":"org","similarity":"dot_product",""" + """
-            "dimensions":1,"max_input_tokens":2}"""));
+            "dimensions":1,"max_input_tokens":2,"rate_limit":{"requests_per_minute":2000}}"""));
     }

     @Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java
index cdee7c452ff5..7e3bdd6b8e5d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java
@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.inference.services.settings;

+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -14,6 +15,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;

 import java.io.IOException;
 import java.util.HashMap;
@@ -49,7 +51,7 @@ public class RateLimitSettingsTests extends AbstractWireSerializingTestCase
         Map settings = new HashMap<>(
             Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))
         );
-        var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation);
+        var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT);

         assertThat(res, is(new RateLimitSettings(100)));
         assertTrue(validation.validationErrors().isEmpty());
@@ -60,7 +62,7 @@ public class RateLimitSettingsTests extends AbstractWireSerializingTestCase
         Map settings = new HashMap<>(
             Map.of("abc", new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))
         );
-        var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation);
+        var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT);

         assertThat(res, is(new RateLimitSettings(1)));
         assertTrue(validation.validationErrors().isEmpty());
@@ -69,12 +71,24 @@ public class RateLimitSettingsTests extends AbstractWireSerializingTestCase
         Map settings = new HashMap<>(Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of("abc", 100))));
-        var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation);
+        var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT);

         assertThat(res, is(new RateLimitSettings(1)));
         assertTrue(validation.validationErrors().isEmpty());
     }

+    public void testOf_ThrowsException_WithUnknownField_InRequestContext() {
+        var validation = new ValidationException();
+        Map settings = new HashMap<>(Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of("abc", 100))));
+
+        var exception = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.REQUEST)
+        );
+
+        assertThat(exception.getMessage(), is("Model configuration contains settings [{abc=100}] unknown to the [test] service"));
+    }
+
     public void testToXContent() throws IOException {
         var settings = new RateLimitSettings(100);
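The tests above exercise the new `RateLimitSettings.of` overload that takes the service name and a `ConfigurationParseContext`. As a rough, non-authoritative sketch of how a service might thread that context through (the `ExampleServiceSettingsParser` class, its `NAME`, and its default of 240 requests per minute below are hypothetical and not part of this change; only the `RateLimitSettings.of` signature mirrors the tests above):

[source,java]
----
// Illustrative sketch only: ExampleServiceSettingsParser, NAME, and the 240 rpm default
// are made-up for this example. The RateLimitSettings.of call mirrors the tests above:
// in a REQUEST context unknown keys under "rate_limit" are rejected, while in a
// PERSISTENT context they are ignored so previously stored configurations still load.
import java.util.Map;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

public class ExampleServiceSettingsParser {

    private static final String NAME = "example_service";
    private static final RateLimitSettings DEFAULT_RATE_LIMIT = new RateLimitSettings(240);

    public static RateLimitSettings parseRateLimit(Map<String, Object> settingsMap, ConfigurationParseContext context) {
        var validation = new ValidationException();
        var rateLimit = RateLimitSettings.of(settingsMap, DEFAULT_RATE_LIMIT, validation, NAME, context);

        // Surface any accumulated validation errors to the caller.
        if (validation.validationErrors().isEmpty() == false) {
            throw validation;
        }
        return rateLimit;
    }
}
----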