mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[ML] Only suppoer ner pytorch model testing in UI (#129103)
This commit is contained in:
parent
862dac8e24
commit
a7b239f8d9
3 changed files with 36 additions and 9 deletions
|
@ -18,5 +18,14 @@ export const TRAINED_MODEL_TYPE = {
|
|||
TREE_ENSEMBLE: 'tree_ensemble',
|
||||
LANG_IDENT: 'lang_ident',
|
||||
} as const;
|
||||
|
||||
export type TrainedModelType = typeof TRAINED_MODEL_TYPE[keyof typeof TRAINED_MODEL_TYPE];
|
||||
|
||||
export const SUPPORTED_PYTORCH_TASKS = {
|
||||
NER: 'ner',
|
||||
// ZERO_SHOT_CLASSIFICATION: 'zero_shot_classification',
|
||||
// CLASSIFICATION_LABELS: 'classification_labels',
|
||||
// TEXT_CLASSIFICATION: 'text_classification',
|
||||
// TEXT_EMBEDDING: 'text_embedding',
|
||||
} as const;
|
||||
export type SupportedPytorchTasksType =
|
||||
typeof SUPPORTED_PYTORCH_TASKS[keyof typeof SUPPORTED_PYTORCH_TASKS];
|
||||
|
|
|
@ -13,7 +13,10 @@ import type { FormattedNerResp } from './models/ner';
|
|||
import { LangIdentOutput, LangIdentInference } from './models/lang_ident';
|
||||
import type { FormattedLangIdentResp } from './models/lang_ident';
|
||||
|
||||
import { TRAINED_MODEL_TYPE } from '../../../../../common/constants/trained_models';
|
||||
import {
|
||||
TRAINED_MODEL_TYPE,
|
||||
SUPPORTED_PYTORCH_TASKS,
|
||||
} from '../../../../../common/constants/trained_models';
|
||||
import { useMlApiContext } from '../../../contexts/kibana';
|
||||
import { InferenceInputForm } from './models/inference_input_form';
|
||||
|
||||
|
@ -28,7 +31,10 @@ export const SelectedModel: FC<Props> = ({ model }) => {
|
|||
return null;
|
||||
}
|
||||
|
||||
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
|
||||
if (
|
||||
model.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
|
||||
Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.NER
|
||||
) {
|
||||
const inferrer = new NerInference(trainedModels, model);
|
||||
return (
|
||||
<InferenceInputForm
|
||||
|
|
|
@ -6,13 +6,25 @@
|
|||
*/
|
||||
|
||||
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
|
||||
import { TRAINED_MODEL_TYPE } from '../../../../../common/constants/trained_models';
|
||||
import {
|
||||
TRAINED_MODEL_TYPE,
|
||||
SUPPORTED_PYTORCH_TASKS,
|
||||
} from '../../../../../common/constants/trained_models';
|
||||
import type { SupportedPytorchTasksType } from '../../../../../common/constants/trained_models';
|
||||
|
||||
const TESTABLE_MODEL_TYPES: estypes.MlTrainedModelType[] = [
|
||||
TRAINED_MODEL_TYPE.PYTORCH,
|
||||
TRAINED_MODEL_TYPE.LANG_IDENT,
|
||||
];
|
||||
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);
|
||||
|
||||
export function isTestable(model: estypes.MlTrainedModelConfig) {
|
||||
return model.model_type && TESTABLE_MODEL_TYPES.includes(model.model_type);
|
||||
if (
|
||||
model.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
|
||||
PYTORCH_TYPES.includes(Object.keys(model.inference_config)[0] as SupportedPytorchTasksType)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue