mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
(cherry picked from commit c6a00589b0
)
Co-authored-by: James Gowdy <jgowdy@elastic.co>
This commit is contained in:
parent
aa4540df9f
commit
ed16c6cd52
1 changed files with 30 additions and 42 deletions
|
@ -6,7 +6,7 @@
|
|||
*/
|
||||
|
||||
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import React, { FC } from 'react';
|
||||
import React, { FC, useMemo } from 'react';
|
||||
|
||||
import { NerInference } from './models/ner';
|
||||
import { QuestionAnsweringInference } from './models/question_answering';
|
||||
|
@ -28,53 +28,41 @@ import { useMlApiContext } from '../../../contexts/kibana';
|
|||
import { InferenceInputForm } from './models/inference_input_form';
|
||||
|
||||
interface Props {
|
||||
model: estypes.MlTrainedModelConfig | null;
|
||||
model: estypes.MlTrainedModelConfig;
|
||||
}
|
||||
|
||||
export const SelectedModel: FC<Props> = ({ model }) => {
|
||||
const { trainedModels } = useMlApiContext();
|
||||
|
||||
if (model === null) {
|
||||
const inferrer = useMemo(() => {
|
||||
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
|
||||
const taskType = Object.keys(model.inference_config)[0];
|
||||
|
||||
switch (taskType) {
|
||||
case SUPPORTED_PYTORCH_TASKS.NER:
|
||||
return new NerInference(trainedModels, model);
|
||||
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
|
||||
return new TextClassificationInference(trainedModels, model);
|
||||
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
|
||||
return new ZeroShotClassificationInference(trainedModels, model);
|
||||
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
|
||||
return new TextEmbeddingInference(trainedModels, model);
|
||||
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
|
||||
return new FillMaskInference(trainedModels, model);
|
||||
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
|
||||
return new QuestionAnsweringInference(trainedModels, model);
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
|
||||
return new LangIdentInference(trainedModels, model);
|
||||
}
|
||||
}, [model, trainedModels]);
|
||||
|
||||
if (inferrer === undefined) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
|
||||
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.NER) {
|
||||
const inferrer = new NerInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
|
||||
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION) {
|
||||
const inferrer = new TextClassificationInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
|
||||
if (
|
||||
Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION
|
||||
) {
|
||||
const inferrer = new ZeroShotClassificationInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
|
||||
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
|
||||
const inferrer = new TextEmbeddingInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
|
||||
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.FILL_MASK) {
|
||||
const inferrer = new FillMaskInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
|
||||
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING) {
|
||||
const inferrer = new QuestionAnsweringInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
}
|
||||
if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
|
||||
const inferrer = new LangIdentInference(trainedModels, model);
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
}
|
||||
|
||||
return null;
|
||||
return <InferenceInputForm inferrer={inferrer} />;
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue