[ES|QL] Traversal high level APIs (#189957)

## Summary

Closes https://github.com/elastic/kibana/issues/182255

Implements high-level AST node finding APIs, which satisfy requirements
in https://github.com/elastic/kibana/issues/182255

- Introduces the `"visitAny"` callback, which captures any AST node (not
captured by a more specific callback).
- `Walker.find(ast, predicate)` — finds the first AST node that
satisfies a `predicate` condition.
- `Walker.findAll(ast, predicate)` — finds all AST nodes that
satisfy a `predicate` condition.
- `Walker.match(ast, template)` — finds the first AST node that
*matches* the provided node `template`.
- `Walker.matchAll(ast, template)` — finds all AST nodes that
*match* the provided node `template`.


For example, here is how you find all `agg1` and `agg2` function nodes:

```ts
const nodes = Walker.matchAll(ast, { type: 'function', name: ['agg1', 'agg2'] });
```

### Checklist

Delete any items that are not applicable to this PR.

- [x]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials

### For maintainers

- [x] This was checked for breaking API changes and was [labeled
appropriately](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)
This commit is contained in:
Vadim Kibana 2024-08-06 19:52:51 +02:00 committed by GitHub
parent 36afc2fdbd
commit 7332211c72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 545 additions and 19 deletions

View file

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
import { ESQLProperNode } from '../types';
/**
 * Matcher accepted for a single key of a node match template: either a value
 * of the node property's own type `V`, an array of such values, or a `RegExp`.
 */
export type NodeMatchTemplateKey<V> = V | V[] | RegExp;
/**
 * Template object describing the AST nodes to match. Every key of
 * `ESQLProperNode` may optionally be present, mapped to a matcher for that
 * property.
 */
export type NodeMatchTemplate = {
[K in keyof ESQLProperNode]?: NodeMatchTemplateKey<ESQLProperNode[K]>;
};
/**
 * Creates a predicate function which matches a single AST node against a
 * template object. Only the keys present in the template are inspected; a node
 * matches when every one of those keys matches. For each key the template
 * value is interpreted as follows:
 *
 * - An array matches if the node's value for that key is an element of the
 *   array (compared the way `Array.prototype.includes` compares).
 * - A RegExp matches if the node's value, converted with `String()`, matches
 *   the RegExp. Global (`g`) and sticky (`y`) expressions are safe to use:
 *   their `lastIndex` is reset before every test.
 * - Any other value matches if the node's value is triple-equal to it.
 *
 * An empty template matches every node.
 *
 * @param template Template from which to create a predicate function.
 * @returns A predicate function that matches nodes against the template.
 */
export const templateToPredicate = (
  template: NodeMatchTemplate
): ((node: ESQLProperNode) => boolean) => {
  // The key list is captured once, so the predicate does not re-enumerate the
  // template on every node.
  const keys = Object.keys(template) as Array<keyof ESQLProperNode>;
  const predicate = (child: ESQLProperNode): boolean => {
    for (const key of keys) {
      const matcher: unknown = template[key];
      const value: unknown = child[key];
      if (Array.isArray(matcher)) {
        if (!matcher.includes(value)) {
          return false;
        }
      } else if (matcher instanceof RegExp) {
        // `test()` on a global/sticky RegExp resumes from `lastIndex`, which
        // would leak state from the previously tested node into this one.
        matcher.lastIndex = 0;
        if (!matcher.test(String(value))) {
          return false;
        }
      } else if (value !== matcher) {
        return false;
      }
    }
    return true;
  };
  return predicate;
};

View file

@ -81,6 +81,24 @@ describe('structurally can walk all nodes', () => {
]);
});
test('"visitAny" can capture command nodes', () => {
const { ast } = getAstAndSyntaxErrors('FROM index | STATS a = 123 | WHERE 123 | LIMIT 10');
const commands: ESQLCommand[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'command') commands.push(node);
},
});
expect(commands.map(({ name }) => name).sort()).toStrictEqual([
'from',
'limit',
'stats',
'where',
]);
});
describe('command options', () => {
test('can visit command options', () => {
const { ast } = getAstAndSyntaxErrors('FROM index METADATA _index');
@ -93,19 +111,47 @@ describe('structurally can walk all nodes', () => {
expect(options.length).toBe(1);
expect(options[0].name).toBe('metadata');
});
test('"visitAny" can capture an options node', () => {
const { ast } = getAstAndSyntaxErrors('FROM index METADATA _index');
const options: ESQLCommandOption[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'option') options.push(node);
},
});
expect(options.length).toBe(1);
expect(options[0].name).toBe('metadata');
});
});
describe('command mode', () => {
test('visits "mode" nodes', () => {
const { ast } = getAstAndSyntaxErrors('FROM index | ENRICH a:b');
const options: ESQLCommandMode[] = [];
const modes: ESQLCommandMode[] = [];
walk(ast, {
visitCommandMode: (opt) => options.push(opt),
visitCommandMode: (opt) => modes.push(opt),
});
expect(options.length).toBe(1);
expect(options[0].name).toBe('a');
expect(modes.length).toBe(1);
expect(modes[0].name).toBe('a');
});
test('"visitAny" can capture a mode node', () => {
const { ast } = getAstAndSyntaxErrors('FROM index | ENRICH a:b');
const modes: ESQLCommandMode[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'mode') modes.push(node);
},
});
expect(modes.length).toBe(1);
expect(modes[0].name).toBe('a');
});
});
@ -123,6 +169,20 @@ describe('structurally can walk all nodes', () => {
expect(sources[0].name).toBe('index');
});
test('"visitAny" can capture a source node', () => {
const { ast } = getAstAndSyntaxErrors('FROM index');
const sources: ESQLSource[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'source') sources.push(node);
},
});
expect(sources.length).toBe(1);
expect(sources[0].name).toBe('index');
});
test('iterates through all sources', () => {
const { ast } = getAstAndSyntaxErrors('METRICS index, index2, index3, index4');
const sources: ESQLSource[] = [];
@ -142,7 +202,7 @@ describe('structurally can walk all nodes', () => {
});
describe('columns', () => {
test('can through a single column', () => {
test('can walk through a single column', () => {
const query = 'ROW x = 1';
const { ast } = getAstAndSyntaxErrors(query);
const columns: ESQLColumn[] = [];
@ -159,6 +219,25 @@ describe('structurally can walk all nodes', () => {
]);
});
test('"visitAny" can capture a column', () => {
const query = 'ROW x = 1';
const { ast } = getAstAndSyntaxErrors(query);
const columns: ESQLColumn[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'column') columns.push(node);
},
});
expect(columns).toMatchObject([
{
type: 'column',
name: 'x',
},
]);
});
test('can walk through multiple columns', () => {
const query = 'FROM index | STATS a = 123, b = 456';
const { ast } = getAstAndSyntaxErrors(query);
@ -181,6 +260,52 @@ describe('structurally can walk all nodes', () => {
});
});
describe('functions', () => {
test('can walk through functions', () => {
const query = 'FROM a | STATS fn(1), agg(true)';
const { ast } = getAstAndSyntaxErrors(query);
const nodes: ESQLFunction[] = [];
walk(ast, {
visitFunction: (node) => nodes.push(node),
});
expect(nodes).toMatchObject([
{
type: 'function',
name: 'fn',
},
{
type: 'function',
name: 'agg',
},
]);
});
test('"visitAny" can capture function nodes', () => {
const query = 'FROM a | STATS fn(1), agg(true)';
const { ast } = getAstAndSyntaxErrors(query);
const nodes: ESQLFunction[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'function') nodes.push(node);
},
});
expect(nodes).toMatchObject([
{
type: 'function',
name: 'fn',
},
{
type: 'function',
name: 'agg',
},
]);
});
});
describe('literals', () => {
test('can walk a single literal', () => {
const query = 'ROW x = 1';
@ -301,6 +426,20 @@ describe('structurally can walk all nodes', () => {
]);
});
test('"visitAny" can capture a list literal', () => {
const query = 'ROW x = [1, 2]';
const { ast } = getAstAndSyntaxErrors(query);
const lists: ESQLList[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'list') lists.push(node);
},
});
expect(lists.length).toBe(1);
});
test('can walk plain literals inside list literal', () => {
const query = 'ROW x = [1, 2] + [3.3]';
const { ast } = getAstAndSyntaxErrors(query);
@ -492,7 +631,6 @@ describe('structurally can walk all nodes', () => {
test('can visit time interval nodes', () => {
const query = 'FROM index | STATS a = 123 BY 1h';
const { ast } = getAstAndSyntaxErrors(query);
const intervals: ESQLTimeInterval[] = [];
walk(ast, {
@ -507,6 +645,43 @@ describe('structurally can walk all nodes', () => {
},
]);
});
test('"visitAny" can capture time interval expressions', () => {
const query = 'FROM index | STATS a = 123 BY 1h';
const { ast } = getAstAndSyntaxErrors(query);
const intervals: ESQLTimeInterval[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'timeInterval') intervals.push(node);
},
});
expect(intervals).toMatchObject([
{
type: 'timeInterval',
quantity: 1,
unit: 'h',
},
]);
});
test('"visitAny" does not capture time interval node if type-specific callback provided', () => {
const query = 'FROM index | STATS a = 123 BY 1h';
const { ast } = getAstAndSyntaxErrors(query);
const intervals1: ESQLTimeInterval[] = [];
const intervals2: ESQLTimeInterval[] = [];
walk(ast, {
visitTimeIntervalLiteral: (node) => intervals1.push(node),
visitAny: (node) => {
if (node.type === 'timeInterval') intervals2.push(node);
},
});
expect(intervals1.length).toBe(1);
expect(intervals2.length).toBe(0);
});
});
describe('cast expression', () => {
@ -532,6 +707,30 @@ describe('structurally can walk all nodes', () => {
},
]);
});
test('"visitAny" can capture cast expression', () => {
const query = 'FROM index | STATS a = 123::integer';
const { ast } = getAstAndSyntaxErrors(query);
const casts: ESQLInlineCast[] = [];
walk(ast, {
visitAny: (node) => {
if (node.type === 'inlineCast') casts.push(node);
},
});
expect(casts).toMatchObject([
{
type: 'inlineCast',
castType: 'integer',
value: {
type: 'literal',
literalType: 'integer',
value: 123,
},
},
]);
});
});
});
});
@ -576,7 +775,7 @@ describe('Walker.commands()', () => {
});
});
describe('Walker.params', () => {
describe('Walker.params()', () => {
test('can collect all params', () => {
const query = 'ROW x = ?';
const { ast } = getAstAndSyntaxErrors(query);
@ -613,10 +812,195 @@ describe('Walker.params', () => {
});
});
describe('Walker.find()', () => {
test('can find a bucket() function', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const fn = Walker.find(
getAstAndSyntaxErrors(query).ast!,
(node) => node.type === 'function' && node.name === 'bucket'
);
expect(fn).toMatchObject({
type: 'function',
name: 'bucket',
});
});
test('finds the first "fn" function', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const fn = Walker.find(
getAstAndSyntaxErrors(query).ast!,
(node) => node.type === 'function' && node.name === 'fn'
);
expect(fn).toMatchObject({
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 1,
},
],
});
});
});
describe('Walker.findAll()', () => {
test('find all "fn" functions', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const list = Walker.findAll(
getAstAndSyntaxErrors(query).ast!,
(node) => node.type === 'function' && node.name === 'fn'
);
expect(list).toMatchObject([
{
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 1,
},
],
},
{
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 2,
},
],
},
]);
});
});
describe('Walker.match()', () => {
test('can find a bucket() function', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const fn = Walker.match(getAstAndSyntaxErrors(query).ast!, {
type: 'function',
name: 'bucket',
});
expect(fn).toMatchObject({
type: 'function',
name: 'bucket',
});
});
test('finds the first "fn" function', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const fn = Walker.match(getAstAndSyntaxErrors(query).ast!, { type: 'function', name: 'fn' });
expect(fn).toMatchObject({
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 1,
},
],
});
});
});
describe('Walker.matchAll()', () => {
test('find all "fn" functions', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const list = Walker.matchAll(getAstAndSyntaxErrors(query).ast!, {
type: 'function',
name: 'fn',
});
expect(list).toMatchObject([
{
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 1,
},
],
},
{
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 2,
},
],
},
]);
});
test('find all "fn" and "agg" functions', () => {
const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
const list = Walker.matchAll(getAstAndSyntaxErrors(query).ast!, {
type: 'function',
name: ['fn', 'agg'],
});
expect(list).toMatchObject([
{
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 1,
},
],
},
{
type: 'function',
name: 'fn',
args: [
{
type: 'literal',
value: 2,
},
],
},
{
type: 'function',
name: 'agg',
},
]);
});
test('find all functions which start with "b" or "a"', () => {
  const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)';
  // The character class keeps the `^` anchor applied to both letters. The
  // ungrouped form `/^a|b/i` parses as `(^a)|(b)` and would also match names
  // that merely contain a "b".
  const list = Walker.matchAll(getAstAndSyntaxErrors(query).ast!, {
    type: 'function',
    name: /^[ab]/i,
  });
  expect(list).toMatchObject([
    {
      type: 'function',
      name: 'bucket',
    },
    {
      type: 'function',
      name: 'agg',
    },
  ]);
});
});
describe('Walker.hasFunction()', () => {
test('can find assignment expression', () => {
const query1 = 'METRICS source bucket(bytes, 1 hour)';
const query2 = 'METRICS source var0 = bucket(bytes, 1 hour)';
const query1 = 'FROM a | STATS bucket(bytes, 1 hour)';
const query2 = 'FROM b | STATS var0 = bucket(bytes, 1 hour)';
const has1 = Walker.hasFunction(getAstAndSyntaxErrors(query1).ast!, '=');
const has2 = Walker.hasFunction(getAstAndSyntaxErrors(query2).ast!, '=');

View file

@ -19,11 +19,13 @@ import type {
ESQLList,
ESQLLiteral,
ESQLParamLiteral,
ESQLProperNode,
ESQLSingleAstItem,
ESQLSource,
ESQLTimeInterval,
ESQLUnknownItem,
} from '../types';
import { NodeMatchTemplate, templateToPredicate } from './helpers';
type Node = ESQLAstNode | ESQLAstNode[];
@ -40,6 +42,13 @@ export interface WalkerOptions {
visitTimeIntervalLiteral?: (node: ESQLTimeInterval) => void;
visitInlineCast?: (node: ESQLInlineCast) => void;
visitUnknown?: (node: ESQLUnknownItem) => void;
/**
* Called for any node type that does not have a specific visitor.
*
* @param node Any valid AST node.
*/
visitAny?: (node: ESQLProperNode) => void;
}
export type WalkerAstNode = ESQLAstNode | ESQLAstNode[];
@ -102,6 +111,82 @@ export class Walker {
return params;
};
/**
 * Searches the AST for the first node accepted by `predicate`.
 *
 * Only the first accepted node is kept; once a result is recorded the
 * callback ignores every node it is handed afterwards.
 *
 * @param node AST node (or list of nodes) to start the search from.
 * @param predicate Returns `true` for a node that satisfies the search criteria.
 * @returns The first node accepted by `predicate`, or `undefined` if none is.
 */
public static readonly find = (
  node: WalkerAstNode,
  predicate: (node: ESQLProperNode) => boolean
): ESQLProperNode | undefined => {
  let result: ESQLProperNode | undefined;
  const visitAny = (candidate: ESQLProperNode): void => {
    if (result) return;
    if (predicate(candidate)) result = candidate;
  };
  Walker.walk(node, { visitAny });
  return result;
};
/**
 * Collects every node in the AST that is accepted by `predicate`, in the
 * order the walker hands them to the callback.
 *
 * @param node AST node (or list of nodes) to start the search from.
 * @param predicate Returns `true` for a node that satisfies the search criteria.
 * @returns All nodes accepted by `predicate`; an empty array if there are none.
 */
public static readonly findAll = (
  node: WalkerAstNode,
  predicate: (node: ESQLProperNode) => boolean
): ESQLProperNode[] => {
  const matches: ESQLProperNode[] = [];
  const visitAny = (candidate: ESQLProperNode): void => {
    if (!predicate(candidate)) return;
    matches.push(candidate);
  };
  Walker.walk(node, { visitAny });
  return matches;
};
/**
 * Finds the first node that matches a template object. The template is
 * converted to a predicate with `templateToPredicate` and the search is
 * delegated to `Walker.find`.
 *
 * @param node AST node (or list of nodes) to search.
 * @param template Template object the nodes are compared against.
 * @returns The first node that matches the template, or `undefined`.
 */
public static readonly match = (
  node: WalkerAstNode,
  template: NodeMatchTemplate
): ESQLProperNode | undefined => Walker.find(node, templateToPredicate(template));
/**
 * Finds every node that matches a template object. The template is converted
 * to a predicate with `templateToPredicate` and the search is delegated to
 * `Walker.findAll`.
 *
 * @param node AST node (or list of nodes) to search.
 * @param template Template object the nodes are compared against.
 * @returns All nodes that match the template.
 */
public static readonly matchAll = (
  node: WalkerAstNode,
  template: NodeMatchTemplate
): ESQLProperNode[] => Walker.findAll(node, templateToPredicate(template));
/**
* Finds the first function that matches the predicate.
*
@ -161,7 +246,8 @@ export class Walker {
}
public walkCommand(node: ESQLAstCommand): void {
this.options.visitCommand?.(node);
const { options } = this;
(options.visitCommand ?? options.visitAny)?.(node);
switch (node.name) {
default: {
this.walk(node.args);
@ -171,7 +257,8 @@ export class Walker {
}
public walkOption(node: ESQLCommandOption): void {
this.options.visitCommandOption?.(node);
const { options } = this;
(options.visitCommandOption ?? options.visitAny)?.(node);
for (const child of node.args) {
this.walkAstItem(child);
}
@ -188,11 +275,13 @@ export class Walker {
}
public walkMode(node: ESQLCommandMode): void {
this.options.visitCommandMode?.(node);
const { options } = this;
(options.visitCommandMode ?? options.visitAny)?.(node);
}
public walkListLiteral(node: ESQLList): void {
this.options.visitListLiteral?.(node);
const { options } = this;
(options.visitListLiteral ?? options.visitAny)?.(node);
for (const value of node.values) {
this.walkAstItem(value);
}
@ -215,11 +304,11 @@ export class Walker {
break;
}
case 'source': {
options.visitSource?.(node);
(options.visitSource ?? options.visitAny)?.(node);
break;
}
case 'column': {
options.visitColumn?.(node);
(options.visitColumn ?? options.visitAny)?.(node);
break;
}
case 'literal': {
@ -231,22 +320,23 @@ export class Walker {
break;
}
case 'timeInterval': {
options.visitTimeIntervalLiteral?.(node);
(options.visitTimeIntervalLiteral ?? options.visitAny)?.(node);
break;
}
case 'inlineCast': {
options.visitInlineCast?.(node);
(options.visitInlineCast ?? options.visitAny)?.(node);
break;
}
case 'unknown': {
options.visitUnknown?.(node);
(options.visitUnknown ?? options.visitAny)?.(node);
break;
}
}
}
public walkFunction(node: ESQLFunction): void {
this.options.visitFunction?.(node);
const { options } = this;
(options.visitFunction ?? options.visitAny)?.(node);
const args = node.args;
const length = args.length;
for (let i = 0; i < length; i++) {