[ML] Only suppoer ner pytorch model testing in UI (#129103)

This commit is contained in:
James Gowdy 2022-04-01 15:00:03 +01:00 committed by GitHub
parent 862dac8e24
commit a7b239f8d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 9 deletions

View file

@ -18,5 +18,14 @@ export const TRAINED_MODEL_TYPE = {
TREE_ENSEMBLE: 'tree_ensemble',
LANG_IDENT: 'lang_ident',
} as const;
export type TrainedModelType = typeof TRAINED_MODEL_TYPE[keyof typeof TRAINED_MODEL_TYPE];
export const SUPPORTED_PYTORCH_TASKS = {
NER: 'ner',
// ZERO_SHOT_CLASSIFICATION: 'zero_shot_classification',
// CLASSIFICATION_LABELS: 'classification_labels',
// TEXT_CLASSIFICATION: 'text_classification',
// TEXT_EMBEDDING: 'text_embedding',
} as const;
export type SupportedPytorchTasksType =
typeof SUPPORTED_PYTORCH_TASKS[keyof typeof SUPPORTED_PYTORCH_TASKS];

View file

@ -13,7 +13,10 @@ import type { FormattedNerResp } from './models/ner';
import { LangIdentOutput, LangIdentInference } from './models/lang_ident';
import type { FormattedLangIdentResp } from './models/lang_ident';
import { TRAINED_MODEL_TYPE } from '../../../../../common/constants/trained_models';
import {
TRAINED_MODEL_TYPE,
SUPPORTED_PYTORCH_TASKS,
} from '../../../../../common/constants/trained_models';
import { useMlApiContext } from '../../../contexts/kibana';
import { InferenceInputForm } from './models/inference_input_form';
@ -28,7 +31,10 @@ export const SelectedModel: FC<Props> = ({ model }) => {
return null;
}
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
if (
model.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.NER
) {
const inferrer = new NerInference(trainedModels, model);
return (
<InferenceInputForm

View file

@ -6,13 +6,25 @@
*/
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { TRAINED_MODEL_TYPE } from '../../../../../common/constants/trained_models';
import {
TRAINED_MODEL_TYPE,
SUPPORTED_PYTORCH_TASKS,
} from '../../../../../common/constants/trained_models';
import type { SupportedPytorchTasksType } from '../../../../../common/constants/trained_models';
const TESTABLE_MODEL_TYPES: estypes.MlTrainedModelType[] = [
TRAINED_MODEL_TYPE.PYTORCH,
TRAINED_MODEL_TYPE.LANG_IDENT,
];
const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);
export function isTestable(model: estypes.MlTrainedModelConfig) {
return model.model_type && TESTABLE_MODEL_TYPES.includes(model.model_type);
if (
model.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
PYTORCH_TYPES.includes(Object.keys(model.inference_config)[0] as SupportedPytorchTasksType)
) {
return true;
}
if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
return true;
}
return false;
}