[ML] Adding missing onFailure call for Inference API start model request (#126930)

* Adding missing onFailure call

* Update docs/changelog/126930.yaml
Jonathan Buttner 2025-04-16 14:07:13 -04:00 committed by GitHub
parent 7d6fda5b06
commit e42c118ec6
4 changed files with 55 additions and 1 deletion

docs/changelog/126930.yaml

@@ -0,0 +1,5 @@
pr: 126930
summary: Adding missing `onFailure` call for Inference API start model request
area: Machine Learning
type: bug
issues: []

BaseElasticsearchInternalService.java

@@ -106,7 +106,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceService
            })
            .<Boolean>andThen((l2, modelDidPut) -> {
                var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
                var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
                var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
                client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
            })
            .addListener(finalListener);
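
The one-line change above routes the start-deployment response through the chain's step listener `l2` instead of `finalListener`, so the result or the exception flows through the `SubscribableListener` chain and completes `finalListener` exactly once. Below is a minimal sketch of that chaining pattern, assuming Elasticsearch's `SubscribableListener` API as used in the hunk; the class `ChainingSketch` and the helper `startDeploymentAsync` are illustrative stand-ins, not code from this commit.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;

final class ChainingSketch {

    // Illustrative stand-in for the asynchronous start-deployment call.
    static void startDeploymentAsync(ActionListener<Boolean> listener, boolean fail) {
        if (fail) {
            listener.onFailure(new RuntimeException("start failed"));
        } else {
            listener.onResponse(Boolean.TRUE);
        }
    }

    static void start(ActionListener<Boolean> finalListener, boolean fail) {
        // The first step stands in for the "model put" call and succeeds immediately.
        SubscribableListener.<Boolean>newForked(l1 -> l1.onResponse(Boolean.TRUE))
            .<Boolean>andThen((l2, modelDidPut) -> {
                // Each step must complete its own listener (l2); the chain then forwards
                // the value or the failure to finalListener. Completing finalListener
                // directly here would bypass the chain.
                startDeploymentAsync(l2, fail);
            })
            .addListener(finalListener);
    }
}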

ElasticsearchInternalModel.java

@@ -105,6 +105,8 @@ public abstract class ElasticsearchInternalModel extends Model {
                && statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
                // Deployment is already started
                listener.onResponse(Boolean.TRUE);
            } else {
                listener.onFailure(e);
            }
            return;
        }
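
The added `else` branch makes the assignment listener complete on every path: a `ResourceAlreadyExistsException` root cause is still treated as success, while any other exception is now forwarded via `onFailure` rather than dropped, which previously left the caller's listener hanging. Below is a minimal sketch of that pattern using `ActionListener.wrap`; the class and method names are illustrative, not the actual `getCreateTrainedModelAssignmentActionListener` implementation.

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;

final class StartListenerSketch {

    static ActionListener<CreateTrainedModelAssignmentAction.Response> forStart(ActionListener<Boolean> listener) {
        return ActionListener.wrap(response -> listener.onResponse(Boolean.TRUE), e -> {
            if (e instanceof ElasticsearchStatusException statusException
                && statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
                // Deployment is already started: report success, as before this commit.
                listener.onResponse(Boolean.TRUE);
            } else {
                // The piece this commit adds: propagate every other failure
                // so the caller's listener is always completed.
                listener.onFailure(e);
            }
        });
    }
}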

ElasticsearchInternalServiceTests.java

@@ -37,6 +37,7 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ParseField;
@@ -49,13 +50,16 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1858,6 +1862,49 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
        }
    }

    public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
        var model = new ElserInternalModel(
            "inference_id",
            TaskType.SPARSE_EMBEDDING,
            "elasticsearch",
            new ElserInternalServiceSettings(
                new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
            ),
            new ElserMlNodeTaskSettings(),
            null
        );

        var client = mock(Client.class);
        when(client.threadPool()).thenReturn(threadPool);

        doAnswer(invocationOnMock -> {
            ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
            var builder = GetTrainedModelsAction.Response.builder();
            builder.setModels(List.of(mock(TrainedModelConfig.class)));
            builder.setTotalCount(1);
            listener.onResponse(builder.build());
            return Void.TYPE;
        }).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());

        doAnswer(invocationOnMock -> {
            ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
            listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
            return Void.TYPE;
        }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());

        try (var service = createService(client)) {
            var actionListener = new PlainActionFuture<Boolean>();
            service.start(model, TimeValue.timeValueSeconds(30), actionListener);

            var exception = expectThrows(
                ElasticsearchStatusException.class,
                () -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
            );
            assertThat(exception.getMessage(), is("failed"));
        }
    }

    private ElasticsearchInternalService createService(Client client) {
        var cs = mock(ClusterService.class);
        var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));