[ML] Fix model test flyout reload (#144318) (#144388)

(cherry picked from commit c6a00589b0)

Co-authored-by: James Gowdy <jgowdy@elastic.co>
This commit is contained in:
Kibana Machine 2022-11-02 06:19:19 -04:00 committed by GitHub
parent aa4540df9f
commit ed16c6cd52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,7 @@
*/
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import React, { FC } from 'react';
import React, { FC, useMemo } from 'react';
import { NerInference } from './models/ner';
import { QuestionAnsweringInference } from './models/question_answering';
@ -28,53 +28,41 @@ import { useMlApiContext } from '../../../contexts/kibana';
import { InferenceInputForm } from './models/inference_input_form';
interface Props {
model: estypes.MlTrainedModelConfig | null;
model: estypes.MlTrainedModelConfig;
}
export const SelectedModel: FC<Props> = ({ model }) => {
const { trainedModels } = useMlApiContext();
if (model === null) {
const inferrer = useMemo(() => {
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
const taskType = Object.keys(model.inference_config)[0];
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
return new NerInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
return new TextClassificationInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
return new ZeroShotClassificationInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
return new TextEmbeddingInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
return new FillMaskInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
return new QuestionAnsweringInference(trainedModels, model);
default:
break;
}
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
return new LangIdentInference(trainedModels, model);
}
}, [model, trainedModels]);
if (inferrer === undefined) {
return null;
}
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.NER) {
const inferrer = new NerInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION) {
const inferrer = new TextClassificationInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
if (
Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION
) {
const inferrer = new ZeroShotClassificationInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
const inferrer = new TextEmbeddingInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.FILL_MASK) {
const inferrer = new FillMaskInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING) {
const inferrer = new QuestionAnsweringInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
}
if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
const inferrer = new LangIdentInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
return null;
return <InferenceInputForm inferrer={inferrer} />;
};