mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-19 04:45:07 -04:00
[ML] Adding missing onFailure call for Inference API start model request (#126930)
* Adding missing onFailure call * Update docs/changelog/126930.yaml
This commit is contained in:
parent
7d6fda5b06
commit
e42c118ec6
4 changed files with 55 additions and 1 deletions
5
docs/changelog/126930.yaml
Normal file
5
docs/changelog/126930.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 126930
|
||||
summary: Adding missing `onFailure` call for Inference API start model request
|
||||
area: Machine Learning
|
||||
type: bug
|
||||
issues: []
|
|
@ -106,7 +106,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
|
|||
})
|
||||
.<Boolean>andThen((l2, modelDidPut) -> {
|
||||
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
|
||||
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
|
||||
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
|
||||
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
|
||||
})
|
||||
.addListener(finalListener);
|
||||
|
|
|
@ -105,6 +105,8 @@ public abstract class ElasticsearchInternalModel extends Model {
|
|||
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
|
||||
// Deployment is already started
|
||||
listener.onResponse(Boolean.TRUE);
|
||||
} else {
|
||||
listener.onFailure(e);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.elasticsearch.inference.Model;
|
|||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
|
@ -49,13 +50,16 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
|
|||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.core.ml.MachineLearningField;
|
||||
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
||||
|
@ -1858,6 +1862,49 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
|
||||
var model = new ElserInternalModel(
|
||||
"inference_id",
|
||||
TaskType.SPARSE_EMBEDDING,
|
||||
"elasticsearch",
|
||||
new ElserInternalServiceSettings(
|
||||
new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
|
||||
),
|
||||
new ElserMlNodeTaskSettings(),
|
||||
null
|
||||
);
|
||||
|
||||
var client = mock(Client.class);
|
||||
when(client.threadPool()).thenReturn(threadPool);
|
||||
|
||||
doAnswer(invocationOnMock -> {
|
||||
ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
|
||||
var builder = GetTrainedModelsAction.Response.builder();
|
||||
builder.setModels(List.of(mock(TrainedModelConfig.class)));
|
||||
builder.setTotalCount(1);
|
||||
|
||||
listener.onResponse(builder.build());
|
||||
return Void.TYPE;
|
||||
}).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
|
||||
|
||||
doAnswer(invocationOnMock -> {
|
||||
ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
|
||||
listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
|
||||
return Void.TYPE;
|
||||
}).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
|
||||
|
||||
try (var service = createService(client)) {
|
||||
var actionListener = new PlainActionFuture<Boolean>();
|
||||
service.start(model, TimeValue.timeValueSeconds(30), actionListener);
|
||||
var exception = expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
|
||||
);
|
||||
|
||||
assertThat(exception.getMessage(), is("failed"));
|
||||
}
|
||||
}
|
||||
|
||||
private ElasticsearchInternalService createService(Client client) {
|
||||
var cs = mock(ClusterService.class);
|
||||
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
|
||||
|
|
Loading…
Add table
Reference in a new issue