diff --git a/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java b/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java index fe0f82560894..ce951715939f 100644 --- a/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java +++ b/distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java @@ -167,12 +167,16 @@ final class SystemJvmOptions { } catch (IOException e) { throw new IllegalStateException("Failed to list entitlement jars in: " + dir, e); } + // We instrument classes in these modules to call the bridge. Because the bridge gets patched + // into java.base, we must export the bridge from java.base to these modules. + String modulesContainingEntitlementInstrumentation = "java.logging"; return Stream.of( "-Des.entitlements.enabled=true", "-XX:+EnableDynamicAgentLoading", "-Djdk.attach.allowAttachSelf=true", "--patch-module=java.base=" + bridgeJar, - "--add-exports=java.base/org.elasticsearch.entitlement.bridge=org.elasticsearch.entitlement" + "--add-exports=java.base/org.elasticsearch.entitlement.bridge=org.elasticsearch.entitlement," + + modulesContainingEntitlementInstrumentation ); } } diff --git a/docs/changelog/118774.yaml b/docs/changelog/118774.yaml new file mode 100644 index 000000000000..cbd1ca82d1c5 --- /dev/null +++ b/docs/changelog/118774.yaml @@ -0,0 +1,5 @@ +pr: 118774 +summary: Apply default k for knn query eagerly +area: Vector Search +type: bug +issues: [] diff --git a/docs/reference/query-dsl/knn-query.asciidoc b/docs/reference/query-dsl/knn-query.asciidoc index daf9e9499a18..e42bd78d9f14 100644 --- a/docs/reference/query-dsl/knn-query.asciidoc +++ b/docs/reference/query-dsl/knn-query.asciidoc @@ -100,7 +100,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector-builde -- (Optional, integer) The number of nearest neighbors to return from each shard. {es} collects `k` results from each shard, then merges them to find the global top results. -This value must be less than or equal to `num_candidates`. Defaults to `num_candidates`. +This value must be less than or equal to `num_candidates`. Defaults to search request size. -- `num_candidates`:: diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java index 0bcbc19047c8..eaf4d0ad98ef 100644 --- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java +++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java @@ -66,7 +66,7 @@ public class InstrumentationServiceImpl implements InstrumentationService { private static final Type CLASS_TYPE = Type.getType(Class.class); - static MethodKey parseCheckerMethodSignature(String checkerMethodName, Type[] checkerMethodArgumentTypes) { + static ParsedCheckerMethod parseCheckerMethodName(String checkerMethodName) { boolean targetMethodIsStatic; int classNameEndIndex = checkerMethodName.lastIndexOf("$$"); int methodNameStartIndex; @@ -100,9 +100,14 @@ public class InstrumentationServiceImpl implements InstrumentationService { if (targetClassName.isBlank()) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Checker method %s has no class name", checkerMethodName)); } + return new ParsedCheckerMethod(targetClassName, targetMethodName, targetMethodIsStatic, targetMethodIsCtor); + } + + static MethodKey parseCheckerMethodSignature(String checkerMethodName, Type[] checkerMethodArgumentTypes) { + ParsedCheckerMethod checkerMethod = parseCheckerMethodName(checkerMethodName); final List targetParameterTypes; - if (targetMethodIsStatic || targetMethodIsCtor) { + if (checkerMethod.targetMethodIsStatic() || checkerMethod.targetMethodIsCtor()) { if (checkerMethodArgumentTypes.length < 1 || CLASS_TYPE.equals(checkerMethodArgumentTypes[0]) == false) { throw new IllegalArgumentException( String.format( @@ -130,7 +135,13 @@ public class InstrumentationServiceImpl implements InstrumentationService { } targetParameterTypes = Arrays.stream(checkerMethodArgumentTypes).skip(2).map(Type::getInternalName).toList(); } - boolean hasReceiver = (targetMethodIsStatic || targetMethodIsCtor) == false; - return new MethodKey(targetClassName, targetMethodName, targetParameterTypes); + return new MethodKey(checkerMethod.targetClassName(), checkerMethod.targetMethodName(), targetParameterTypes); } + + private record ParsedCheckerMethod( + String targetClassName, + String targetMethodName, + boolean targetMethodIsStatic, + boolean targetMethodIsCtor + ) {} } diff --git a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java index 67d006868b48..66ef1f69c8c3 100644 --- a/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java +++ b/libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java @@ -9,6 +9,13 @@ package org.elasticsearch.entitlement.bridge; +import java.io.InputStream; +import java.io.PrintStream; +import java.io.PrintWriter; +import java.net.ContentHandlerFactory; +import java.net.DatagramSocketImplFactory; +import java.net.FileNameMap; +import java.net.SocketImplFactory; import java.net.URL; import java.net.URLStreamHandlerFactory; import java.util.List; @@ -21,26 +28,42 @@ import javax.net.ssl.SSLSocketFactory; @SuppressWarnings("unused") // Called from instrumentation code inserted by the Entitlements agent public interface EntitlementChecker { + //////////////////// + // // Exit the JVM process + // + void check$java_lang_Runtime$exit(Class callerClass, Runtime runtime, int status); void check$java_lang_Runtime$halt(Class callerClass, Runtime runtime, int status); + //////////////////// + // // ClassLoader ctor + // + void check$java_lang_ClassLoader$(Class callerClass); void check$java_lang_ClassLoader$(Class callerClass, ClassLoader parent); void check$java_lang_ClassLoader$(Class callerClass, String name, ClassLoader parent); + //////////////////// + // // SecureClassLoader ctor + // + void check$java_security_SecureClassLoader$(Class callerClass); void check$java_security_SecureClassLoader$(Class callerClass, ClassLoader parent); void check$java_security_SecureClassLoader$(Class callerClass, String name, ClassLoader parent); + //////////////////// + // // URLClassLoader constructors + // + void check$java_net_URLClassLoader$(Class callerClass, URL[] urls); void check$java_net_URLClassLoader$(Class callerClass, URL[] urls, ClassLoader parent); @@ -51,7 +74,11 @@ public interface EntitlementChecker { void check$java_net_URLClassLoader$(Class callerClass, String name, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory); + //////////////////// + // // "setFactory" methods + // + void check$javax_net_ssl_HttpsURLConnection$setSSLSocketFactory(Class callerClass, HttpsURLConnection conn, SSLSocketFactory sf); void check$javax_net_ssl_HttpsURLConnection$$setDefaultSSLSocketFactory(Class callerClass, SSLSocketFactory sf); @@ -60,9 +87,82 @@ public interface EntitlementChecker { void check$javax_net_ssl_SSLContext$$setDefault(Class callerClass, SSLContext context); + //////////////////// + // // Process creation + // + void check$java_lang_ProcessBuilder$start(Class callerClass, ProcessBuilder that); void check$java_lang_ProcessBuilder$$startPipeline(Class callerClass, List builders); + //////////////////// + // + // JVM-wide state changes + // + + void check$java_lang_System$$setIn(Class callerClass, InputStream in); + + void check$java_lang_System$$setOut(Class callerClass, PrintStream out); + + void check$java_lang_System$$setErr(Class callerClass, PrintStream err); + + void check$java_lang_Runtime$addShutdownHook(Class callerClass, Runtime runtime, Thread hook); + + void check$java_lang_Runtime$removeShutdownHook(Class callerClass, Runtime runtime, Thread hook); + + void check$jdk_tools_jlink_internal_Jlink$(Class callerClass); + + void check$jdk_tools_jlink_internal_Main$$run(Class callerClass, PrintWriter out, PrintWriter err, String... args); + + void check$jdk_vm_ci_services_JVMCIServiceLocator$$getProviders(Class callerClass, Class service); + + void check$jdk_vm_ci_services_Services$$load(Class callerClass, Class service); + + void check$jdk_vm_ci_services_Services$$loadSingle(Class callerClass, Class service, boolean required); + + void check$com_sun_tools_jdi_VirtualMachineManagerImpl$$virtualMachineManager(Class callerClass); + + void check$java_lang_Thread$$setDefaultUncaughtExceptionHandler(Class callerClass, Thread.UncaughtExceptionHandler ueh); + + void check$java_util_spi_LocaleServiceProvider$(Class callerClass); + + void check$java_text_spi_BreakIteratorProvider$(Class callerClass); + + void check$java_text_spi_CollatorProvider$(Class callerClass); + + void check$java_text_spi_DateFormatProvider$(Class callerClass); + + void check$java_text_spi_DateFormatSymbolsProvider$(Class callerClass); + + void check$java_text_spi_DecimalFormatSymbolsProvider$(Class callerClass); + + void check$java_text_spi_NumberFormatProvider$(Class callerClass); + + void check$java_util_spi_CalendarDataProvider$(Class callerClass); + + void check$java_util_spi_CalendarNameProvider$(Class callerClass); + + void check$java_util_spi_CurrencyNameProvider$(Class callerClass); + + void check$java_util_spi_LocaleNameProvider$(Class callerClass); + + void check$java_util_spi_TimeZoneNameProvider$(Class callerClass); + + void check$java_util_logging_LogManager$(Class callerClass); + + void check$java_net_DatagramSocket$$setDatagramSocketImplFactory(Class callerClass, DatagramSocketImplFactory fac); + + void check$java_net_HttpURLConnection$$setFollowRedirects(Class callerClass, boolean set); + + void check$java_net_ServerSocket$$setSocketFactory(Class callerClass, SocketImplFactory fac); + + void check$java_net_Socket$$setSocketImplFactory(Class callerClass, SocketImplFactory fac); + + void check$java_net_URL$$setURLStreamHandlerFactory(Class callerClass, URLStreamHandlerFactory fac); + + void check$java_net_URLConnection$$setFileNameMap(Class callerClass, FileNameMap map); + + void check$java_net_URLConnection$$setContentHandlerFactory(Class callerClass, ContentHandlerFactory fac); + } diff --git a/libs/entitlement/qa/common/src/main/java/module-info.java b/libs/entitlement/qa/common/src/main/java/module-info.java index 2dd37e3174e0..211b7041e97e 100644 --- a/libs/entitlement/qa/common/src/main/java/module-info.java +++ b/libs/entitlement/qa/common/src/main/java/module-info.java @@ -12,5 +12,8 @@ module org.elasticsearch.entitlement.qa.common { requires org.elasticsearch.base; requires org.elasticsearch.logging; + // Modules we'll attempt to use in order to exercise entitlements + requires java.logging; + exports org.elasticsearch.entitlement.qa.common; } diff --git a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/DummyImplementations.java b/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/DummyImplementations.java new file mode 100644 index 000000000000..6dbb684c7151 --- /dev/null +++ b/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/DummyImplementations.java @@ -0,0 +1,334 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.qa.common; + +import java.net.InetAddress; +import java.net.Socket; +import java.security.cert.Certificate; +import java.text.BreakIterator; +import java.text.Collator; +import java.text.DateFormat; +import java.text.DateFormatSymbols; +import java.text.DecimalFormatSymbols; +import java.text.NumberFormat; +import java.text.spi.BreakIteratorProvider; +import java.text.spi.CollatorProvider; +import java.text.spi.DateFormatProvider; +import java.text.spi.DateFormatSymbolsProvider; +import java.text.spi.DecimalFormatSymbolsProvider; +import java.text.spi.NumberFormatProvider; +import java.util.Locale; +import java.util.Map; +import java.util.spi.CalendarDataProvider; +import java.util.spi.CalendarNameProvider; +import java.util.spi.CurrencyNameProvider; +import java.util.spi.LocaleNameProvider; +import java.util.spi.LocaleServiceProvider; +import java.util.spi.TimeZoneNameProvider; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLSocketFactory; + +/** + * A collection of concrete subclasses that we can instantiate but that don't actually work. + *

+ * A bit like Mockito but way more painful. + */ +class DummyImplementations { + + static class DummyLocaleServiceProvider extends LocaleServiceProvider { + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyBreakIteratorProvider extends BreakIteratorProvider { + + @Override + public BreakIterator getWordInstance(Locale locale) { + throw unexpected(); + } + + @Override + public BreakIterator getLineInstance(Locale locale) { + throw unexpected(); + } + + @Override + public BreakIterator getCharacterInstance(Locale locale) { + throw unexpected(); + } + + @Override + public BreakIterator getSentenceInstance(Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyCollatorProvider extends CollatorProvider { + + @Override + public Collator getInstance(Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyDateFormatProvider extends DateFormatProvider { + + @Override + public DateFormat getTimeInstance(int style, Locale locale) { + throw unexpected(); + } + + @Override + public DateFormat getDateInstance(int style, Locale locale) { + throw unexpected(); + } + + @Override + public DateFormat getDateTimeInstance(int dateStyle, int timeStyle, Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyDateFormatSymbolsProvider extends DateFormatSymbolsProvider { + + @Override + public DateFormatSymbols getInstance(Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyDecimalFormatSymbolsProvider extends DecimalFormatSymbolsProvider { + + @Override + public DecimalFormatSymbols getInstance(Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyNumberFormatProvider extends NumberFormatProvider { + + @Override + public NumberFormat getCurrencyInstance(Locale locale) { + throw unexpected(); + } + + @Override + public NumberFormat getIntegerInstance(Locale locale) { + throw unexpected(); + } + + @Override + public NumberFormat getNumberInstance(Locale locale) { + throw unexpected(); + } + + @Override + public NumberFormat getPercentInstance(Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyCalendarDataProvider extends CalendarDataProvider { + + @Override + public int getFirstDayOfWeek(Locale locale) { + throw unexpected(); + } + + @Override + public int getMinimalDaysInFirstWeek(Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyCalendarNameProvider extends CalendarNameProvider { + + @Override + public String getDisplayName(String calendarType, int field, int value, int style, Locale locale) { + throw unexpected(); + } + + @Override + public Map getDisplayNames(String calendarType, int field, int style, Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyCurrencyNameProvider extends CurrencyNameProvider { + + @Override + public String getSymbol(String currencyCode, Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyLocaleNameProvider extends LocaleNameProvider { + + @Override + public String getDisplayLanguage(String languageCode, Locale locale) { + throw unexpected(); + } + + @Override + public String getDisplayCountry(String countryCode, Locale locale) { + throw unexpected(); + } + + @Override + public String getDisplayVariant(String variant, Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyTimeZoneNameProvider extends TimeZoneNameProvider { + + @Override + public String getDisplayName(String ID, boolean daylight, int style, Locale locale) { + throw unexpected(); + } + + @Override + public Locale[] getAvailableLocales() { + throw unexpected(); + } + } + + static class DummyHttpsURLConnection extends HttpsURLConnection { + DummyHttpsURLConnection() { + super(null); + } + + @Override + public void connect() { + throw unexpected(); + } + + @Override + public void disconnect() { + throw unexpected(); + } + + @Override + public boolean usingProxy() { + throw unexpected(); + } + + @Override + public String getCipherSuite() { + throw unexpected(); + } + + @Override + public Certificate[] getLocalCertificates() { + throw unexpected(); + } + + @Override + public Certificate[] getServerCertificates() { + throw unexpected(); + } + } + + static class DummySSLSocketFactory extends SSLSocketFactory { + @Override + public Socket createSocket(String host, int port) { + throw unexpected(); + } + + @Override + public Socket createSocket(String host, int port, InetAddress localHost, int localPort) { + throw unexpected(); + } + + @Override + public Socket createSocket(InetAddress host, int port) { + throw unexpected(); + } + + @Override + public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) { + throw unexpected(); + } + + @Override + public String[] getDefaultCipherSuites() { + throw unexpected(); + } + + @Override + public String[] getSupportedCipherSuites() { + throw unexpected(); + } + + @Override + public Socket createSocket(Socket s, String host, int port, boolean autoClose) { + throw unexpected(); + } + } + + private static RuntimeException unexpected() { + return new IllegalStateException("This method isn't supposed to be called"); + } + +} diff --git a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/RestEntitlementsCheckAction.java b/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/RestEntitlementsCheckAction.java index 4afceedbe3f0..3c12b2f6bc62 100644 --- a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/RestEntitlementsCheckAction.java +++ b/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/RestEntitlementsCheckAction.java @@ -12,6 +12,18 @@ package org.elasticsearch.entitlement.qa.common; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.Strings; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyBreakIteratorProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyCalendarDataProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyCalendarNameProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyCollatorProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyCurrencyNameProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyDateFormatProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyDateFormatSymbolsProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyDecimalFormatSymbolsProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyLocaleNameProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyLocaleServiceProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyNumberFormatProvider; +import org.elasticsearch.entitlement.qa.common.DummyImplementations.DummyTimeZoneNameProvider; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.rest.BaseRestHandler; @@ -21,8 +33,15 @@ import org.elasticsearch.rest.RestStatus; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.DatagramSocket; +import java.net.DatagramSocketImpl; +import java.net.DatagramSocketImplFactory; +import java.net.HttpURLConnection; +import java.net.ServerSocket; +import java.net.Socket; import java.net.URL; import java.net.URLClassLoader; +import java.net.URLConnection; import java.security.NoSuchAlgorithmException; import java.util.List; import java.util.Map; @@ -40,6 +59,7 @@ import static org.elasticsearch.rest.RestRequest.Method.GET; public class RestEntitlementsCheckAction extends BaseRestHandler { private static final Logger logger = LogManager.getLogger(RestEntitlementsCheckAction.class); + public static final Thread NO_OP_SHUTDOWN_HOOK = new Thread(() -> {}, "Shutdown hook for testing"); private final String prefix; record CheckAction(Runnable action, boolean isAlwaysDeniedToPlugins) { @@ -69,11 +89,45 @@ public class RestEntitlementsCheckAction extends BaseRestHandler { entry("set_https_connection_properties", forPlugins(RestEntitlementsCheckAction::setHttpsConnectionProperties)), entry("set_default_ssl_socket_factory", alwaysDenied(RestEntitlementsCheckAction::setDefaultSSLSocketFactory)), entry("set_default_hostname_verifier", alwaysDenied(RestEntitlementsCheckAction::setDefaultHostnameVerifier)), - entry("set_default_ssl_context", alwaysDenied(RestEntitlementsCheckAction::setDefaultSSLContext)) + entry("set_default_ssl_context", alwaysDenied(RestEntitlementsCheckAction::setDefaultSSLContext)), + entry("system_setIn", alwaysDenied(RestEntitlementsCheckAction::system$$setIn)), + entry("system_setOut", alwaysDenied(RestEntitlementsCheckAction::system$$setOut)), + entry("system_setErr", alwaysDenied(RestEntitlementsCheckAction::system$$setErr)), + entry("runtime_addShutdownHook", alwaysDenied(RestEntitlementsCheckAction::runtime$addShutdownHook)), + entry("runtime_removeShutdownHook", alwaysDenied(RestEntitlementsCheckAction::runtime$$removeShutdownHook)), + entry( + "thread_setDefaultUncaughtExceptionHandler", + alwaysDenied(RestEntitlementsCheckAction::thread$$setDefaultUncaughtExceptionHandler) + ), + entry("localeServiceProvider", alwaysDenied(RestEntitlementsCheckAction::localeServiceProvider$)), + entry("breakIteratorProvider", alwaysDenied(RestEntitlementsCheckAction::breakIteratorProvider$)), + entry("collatorProvider", alwaysDenied(RestEntitlementsCheckAction::collatorProvider$)), + entry("dateFormatProvider", alwaysDenied(RestEntitlementsCheckAction::dateFormatProvider$)), + entry("dateFormatSymbolsProvider", alwaysDenied(RestEntitlementsCheckAction::dateFormatSymbolsProvider$)), + entry("decimalFormatSymbolsProvider", alwaysDenied(RestEntitlementsCheckAction::decimalFormatSymbolsProvider$)), + entry("numberFormatProvider", alwaysDenied(RestEntitlementsCheckAction::numberFormatProvider$)), + entry("calendarDataProvider", alwaysDenied(RestEntitlementsCheckAction::calendarDataProvider$)), + entry("calendarNameProvider", alwaysDenied(RestEntitlementsCheckAction::calendarNameProvider$)), + entry("currencyNameProvider", alwaysDenied(RestEntitlementsCheckAction::currencyNameProvider$)), + entry("localeNameProvider", alwaysDenied(RestEntitlementsCheckAction::localeNameProvider$)), + entry("timeZoneNameProvider", alwaysDenied(RestEntitlementsCheckAction::timeZoneNameProvider$)), + entry("logManager", alwaysDenied(RestEntitlementsCheckAction::logManager$)), + + // This group is a bit nasty: if entitlements don't prevent these, then networking is + // irreparably borked for the remainder of the test run. + entry( + "datagramSocket_setDatagramSocketImplFactory", + alwaysDenied(RestEntitlementsCheckAction::datagramSocket$$setDatagramSocketImplFactory) + ), + entry("httpURLConnection_setFollowRedirects", alwaysDenied(RestEntitlementsCheckAction::httpURLConnection$$setFollowRedirects)), + entry("serverSocket_setSocketFactory", alwaysDenied(RestEntitlementsCheckAction::serverSocket$$setSocketFactory)), + entry("socket_setSocketImplFactory", alwaysDenied(RestEntitlementsCheckAction::socket$$setSocketImplFactory)), + entry("url_setURLStreamHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::url$$setURLStreamHandlerFactory)), + entry("urlConnection_setFileNameMap", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setFileNameMap)), + entry("urlConnection_setContentHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setContentHandlerFactory)) ); private static void setDefaultSSLContext() { - logger.info("Calling SSLContext.setDefault"); try { SSLContext.setDefault(SSLContext.getDefault()); } catch (NoSuchAlgorithmException e) { @@ -82,13 +136,11 @@ public class RestEntitlementsCheckAction extends BaseRestHandler { } private static void setDefaultHostnameVerifier() { - logger.info("Calling HttpsURLConnection.setDefaultHostnameVerifier"); HttpsURLConnection.setDefaultHostnameVerifier((hostname, session) -> false); } private static void setDefaultSSLSocketFactory() { - logger.info("Calling HttpsURLConnection.setDefaultSSLSocketFactory"); - HttpsURLConnection.setDefaultSSLSocketFactory(new TestSSLSocketFactory()); + HttpsURLConnection.setDefaultSSLSocketFactory(new DummyImplementations.DummySSLSocketFactory()); } @SuppressForbidden(reason = "Specifically testing Runtime.exit") @@ -126,9 +178,137 @@ public class RestEntitlementsCheckAction extends BaseRestHandler { } private static void setHttpsConnectionProperties() { - logger.info("Calling setSSLSocketFactory"); - var connection = new TestHttpsURLConnection(); - connection.setSSLSocketFactory(new TestSSLSocketFactory()); + new DummyImplementations.DummyHttpsURLConnection().setSSLSocketFactory(new DummyImplementations.DummySSLSocketFactory()); + } + + private static void system$$setIn() { + System.setIn(System.in); + } + + @SuppressForbidden(reason = "This should be a no-op so we don't interfere with system streams") + private static void system$$setOut() { + System.setOut(System.out); + } + + @SuppressForbidden(reason = "This should be a no-op so we don't interfere with system streams") + private static void system$$setErr() { + System.setErr(System.err); + } + + private static void runtime$addShutdownHook() { + Runtime.getRuntime().addShutdownHook(NO_OP_SHUTDOWN_HOOK); + } + + private static void runtime$$removeShutdownHook() { + Runtime.getRuntime().removeShutdownHook(NO_OP_SHUTDOWN_HOOK); + } + + private static void thread$$setDefaultUncaughtExceptionHandler() { + Thread.setDefaultUncaughtExceptionHandler(Thread.getDefaultUncaughtExceptionHandler()); + } + + private static void localeServiceProvider$() { + new DummyLocaleServiceProvider(); + } + + private static void breakIteratorProvider$() { + new DummyBreakIteratorProvider(); + } + + private static void collatorProvider$() { + new DummyCollatorProvider(); + } + + private static void dateFormatProvider$() { + new DummyDateFormatProvider(); + } + + private static void dateFormatSymbolsProvider$() { + new DummyDateFormatSymbolsProvider(); + } + + private static void decimalFormatSymbolsProvider$() { + new DummyDecimalFormatSymbolsProvider(); + } + + private static void numberFormatProvider$() { + new DummyNumberFormatProvider(); + } + + private static void calendarDataProvider$() { + new DummyCalendarDataProvider(); + } + + private static void calendarNameProvider$() { + new DummyCalendarNameProvider(); + } + + private static void currencyNameProvider$() { + new DummyCurrencyNameProvider(); + } + + private static void localeNameProvider$() { + new DummyLocaleNameProvider(); + } + + private static void timeZoneNameProvider$() { + new DummyTimeZoneNameProvider(); + } + + private static void logManager$() { + new java.util.logging.LogManager() { + }; + } + + @SuppressWarnings("deprecation") + @SuppressForbidden(reason = "We're required to prevent calls to this forbidden API") + private static void datagramSocket$$setDatagramSocketImplFactory() { + try { + DatagramSocket.setDatagramSocketImplFactory(new DatagramSocketImplFactory() { + @Override + public DatagramSocketImpl createDatagramSocketImpl() { + throw new IllegalStateException(); + } + }); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + private static void httpURLConnection$$setFollowRedirects() { + HttpURLConnection.setFollowRedirects(HttpURLConnection.getFollowRedirects()); + } + + @SuppressWarnings("deprecation") + @SuppressForbidden(reason = "We're required to prevent calls to this forbidden API") + private static void serverSocket$$setSocketFactory() { + try { + ServerSocket.setSocketFactory(() -> { throw new IllegalStateException(); }); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @SuppressWarnings("deprecation") + @SuppressForbidden(reason = "We're required to prevent calls to this forbidden API") + private static void socket$$setSocketImplFactory() { + try { + Socket.setSocketImplFactory(() -> { throw new IllegalStateException(); }); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + private static void url$$setURLStreamHandlerFactory() { + URL.setURLStreamHandlerFactory(__ -> { throw new IllegalStateException(); }); + } + + private static void urlConnection$$setFileNameMap() { + URLConnection.setFileNameMap(__ -> { throw new IllegalStateException(); }); + } + + private static void urlConnection$$setContentHandlerFactory() { + URLConnection.setContentHandlerFactory(__ -> { throw new IllegalStateException(); }); } public RestEntitlementsCheckAction(String prefix) { diff --git a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/TestHttpsURLConnection.java b/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/TestHttpsURLConnection.java deleted file mode 100644 index 5a96e582db02..000000000000 --- a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/TestHttpsURLConnection.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.qa.common; - -import java.io.IOException; -import java.security.cert.Certificate; - -import javax.net.ssl.HttpsURLConnection; -import javax.net.ssl.SSLPeerUnverifiedException; - -class TestHttpsURLConnection extends HttpsURLConnection { - TestHttpsURLConnection() { - super(null); - } - - @Override - public void connect() throws IOException {} - - @Override - public void disconnect() {} - - @Override - public boolean usingProxy() { - return false; - } - - @Override - public String getCipherSuite() { - return ""; - } - - @Override - public Certificate[] getLocalCertificates() { - return new Certificate[0]; - } - - @Override - public Certificate[] getServerCertificates() throws SSLPeerUnverifiedException { - return new Certificate[0]; - } -} diff --git a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/TestSSLSocketFactory.java b/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/TestSSLSocketFactory.java deleted file mode 100644 index feb19df78017..000000000000 --- a/libs/entitlement/qa/common/src/main/java/org/elasticsearch/entitlement/qa/common/TestSSLSocketFactory.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.entitlement.qa.common; - -import java.io.IOException; -import java.net.InetAddress; -import java.net.Socket; -import java.net.UnknownHostException; - -import javax.net.ssl.SSLSocketFactory; - -class TestSSLSocketFactory extends SSLSocketFactory { - @Override - public Socket createSocket(String host, int port) throws IOException, UnknownHostException { - return null; - } - - @Override - public Socket createSocket(String host, int port, InetAddress localHost, int localPort) { - return null; - } - - @Override - public Socket createSocket(InetAddress host, int port) throws IOException { - return null; - } - - @Override - public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException { - return null; - } - - @Override - public String[] getDefaultCipherSuites() { - return new String[0]; - } - - @Override - public String[] getSupportedCipherSuites() { - return new String[0]; - } - - @Override - public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { - return null; - } -} diff --git a/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsDeniedIT.java b/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsDeniedIT.java index b17a57512cde..e2e5a3c4c61e 100644 --- a/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsDeniedIT.java +++ b/libs/entitlement/qa/src/javaRestTest/java/org/elasticsearch/entitlement/qa/EntitlementsDeniedIT.java @@ -32,7 +32,7 @@ public class EntitlementsDeniedIT extends ESRestTestCase { .systemProperty("es.entitlements.enabled", "true") .setting("xpack.security.enabled", "false") // Logs in libs/entitlement/qa/build/test-results/javaRestTest/TEST-org.elasticsearch.entitlement.qa.EntitlementsDeniedIT.xml - .setting("logger.org.elasticsearch.entitlement", "TRACE") + // .setting("logger.org.elasticsearch.entitlement", "DEBUG") .build(); @Override diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java index 450786ee57d8..c0a047dc1a45 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java @@ -12,6 +12,13 @@ package org.elasticsearch.entitlement.runtime.api; import org.elasticsearch.entitlement.bridge.EntitlementChecker; import org.elasticsearch.entitlement.runtime.policy.PolicyManager; +import java.io.InputStream; +import java.io.PrintStream; +import java.io.PrintWriter; +import java.net.ContentHandlerFactory; +import java.net.DatagramSocketImplFactory; +import java.net.FileNameMap; +import java.net.SocketImplFactory; import java.net.URL; import java.net.URLStreamHandlerFactory; import java.util.List; @@ -115,6 +122,166 @@ public class ElasticsearchEntitlementChecker implements EntitlementChecker { policyManager.checkStartProcess(callerClass); } + @Override + public void check$java_lang_System$$setIn(Class callerClass, InputStream in) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_lang_System$$setOut(Class callerClass, PrintStream out) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_lang_System$$setErr(Class callerClass, PrintStream err) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_lang_Runtime$addShutdownHook(Class callerClass, Runtime runtime, Thread hook) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_lang_Runtime$removeShutdownHook(Class callerClass, Runtime runtime, Thread hook) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$jdk_tools_jlink_internal_Jlink$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$jdk_tools_jlink_internal_Main$$run(Class callerClass, PrintWriter out, PrintWriter err, String... args) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$jdk_vm_ci_services_JVMCIServiceLocator$$getProviders(Class callerClass, Class service) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$jdk_vm_ci_services_Services$$load(Class callerClass, Class service) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$jdk_vm_ci_services_Services$$loadSingle(Class callerClass, Class service, boolean required) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$com_sun_tools_jdi_VirtualMachineManagerImpl$$virtualMachineManager(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_lang_Thread$$setDefaultUncaughtExceptionHandler(Class callerClass, Thread.UncaughtExceptionHandler ueh) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_spi_LocaleServiceProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_text_spi_BreakIteratorProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_text_spi_CollatorProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_text_spi_DateFormatProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_text_spi_DateFormatSymbolsProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_text_spi_DecimalFormatSymbolsProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_text_spi_NumberFormatProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_spi_CalendarDataProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_spi_CalendarNameProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_spi_CurrencyNameProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_spi_LocaleNameProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_spi_TimeZoneNameProvider$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_util_logging_LogManager$(Class callerClass) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_DatagramSocket$$setDatagramSocketImplFactory(Class callerClass, DatagramSocketImplFactory fac) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_HttpURLConnection$$setFollowRedirects(Class callerClass, boolean set) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_ServerSocket$$setSocketFactory(Class callerClass, SocketImplFactory fac) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_Socket$$setSocketImplFactory(Class callerClass, SocketImplFactory fac) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_URL$$setURLStreamHandlerFactory(Class callerClass, URLStreamHandlerFactory fac) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_URLConnection$$setFileNameMap(Class callerClass, FileNameMap map) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + + @Override + public void check$java_net_URLConnection$$setContentHandlerFactory(Class callerClass, ContentHandlerFactory fac) { + policyManager.checkChangeJVMGlobalState(callerClass); + } + @Override public void check$javax_net_ssl_HttpsURLConnection$setSSLSocketFactory( Class callerClass, @@ -126,16 +293,16 @@ public class ElasticsearchEntitlementChecker implements EntitlementChecker { @Override public void check$javax_net_ssl_HttpsURLConnection$$setDefaultSSLSocketFactory(Class callerClass, SSLSocketFactory sf) { - policyManager.checkSetGlobalHttpsConnectionProperties(callerClass); + policyManager.checkChangeJVMGlobalState(callerClass); } @Override public void check$javax_net_ssl_HttpsURLConnection$$setDefaultHostnameVerifier(Class callerClass, HostnameVerifier hv) { - policyManager.checkSetGlobalHttpsConnectionProperties(callerClass); + policyManager.checkChangeJVMGlobalState(callerClass); } @Override public void check$javax_net_ssl_SSLContext$$setDefault(Class callerClass, SSLContext context) { - policyManager.checkSetGlobalHttpsConnectionProperties(callerClass); + policyManager.checkChangeJVMGlobalState(callerClass); } } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java index 188ce1d747db..9c45f2d42f03 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java @@ -23,12 +23,15 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; import static java.lang.StackWalker.Option.RETAIN_CLASS_REFERENCE; import static java.util.Objects.requireNonNull; +import static java.util.function.Predicate.not; import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.toUnmodifiableMap; public class PolicyManager { private static final Logger logger = LogManager.getLogger(PolicyManager.class); @@ -93,13 +96,13 @@ public class PolicyManager { this.agentEntitlements = agentEntitlements; this.pluginsEntitlements = requireNonNull(pluginPolicies).entrySet() .stream() - .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, e -> buildScopeEntitlementsMap(e.getValue()))); + .collect(toUnmodifiableMap(Map.Entry::getKey, e -> buildScopeEntitlementsMap(e.getValue()))); this.pluginResolver = pluginResolver; this.entitlementsModule = entitlementsModule; } private static Map> buildScopeEntitlementsMap(Policy policy) { - return policy.scopes().stream().collect(Collectors.toUnmodifiableMap(scope -> scope.moduleName(), scope -> scope.entitlements())); + return policy.scopes().stream().collect(toUnmodifiableMap(Scope::moduleName, Scope::entitlements)); } public void checkStartProcess(Class callerClass) { @@ -122,6 +125,26 @@ public class PolicyManager { ); } + /** + * @param operationDescription is only called when the operation is not trivially allowed, meaning the check is about to fail; + * therefore, its performance is not a major concern. + */ + private void neverEntitled(Class callerClass, Supplier operationDescription) { + var requestingModule = requestingClass(callerClass); + if (isTriviallyAllowed(requestingModule)) { + return; + } + + throw new NotEntitledException( + Strings.format( + "Not entitled: caller [%s], module [%s], operation [%s]", + callerClass, + requestingModule.getName(), + operationDescription.get() + ) + ); + } + public void checkExitVM(Class callerClass) { checkEntitlementPresent(callerClass, ExitVMEntitlement.class); } @@ -134,8 +157,23 @@ public class PolicyManager { checkEntitlementPresent(callerClass, SetHttpsConnectionPropertiesEntitlement.class); } - public void checkSetGlobalHttpsConnectionProperties(Class callerClass) { - neverEntitled(callerClass, "set global https connection properties"); + public void checkChangeJVMGlobalState(Class callerClass) { + neverEntitled(callerClass, () -> { + // Look up the check$ method to compose an informative error message. + // This way, we don't need to painstakingly describe every individual global-state change. + Optional checkMethodName = StackWalker.getInstance() + .walk( + frames -> frames.map(StackFrame::getMethodName) + .dropWhile(not(methodName -> methodName.startsWith("check$"))) + .findFirst() + ); + return checkMethodName.map(this::operationDescription).orElse("change JVM global state"); + }); + } + + private String operationDescription(String methodName) { + // TODO: Use a more human-readable description. Perhaps share code with InstrumentationServiceImpl.parseCheckerMethodName + return methodName.substring(methodName.indexOf('$')); } private void checkEntitlementPresent(Class callerClass, Class entitlementClass) { diff --git a/muted-tests.yml b/muted-tests.yml index 2f89d6244a36..dee015015bcc 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -268,6 +268,9 @@ tests: - class: org.elasticsearch.lucene.FullClusterRestartSearchableSnapshotIndexCompatibilityIT method: testSearchableSnapshotUpgrade {p0=9.0.0} issue: https://github.com/elastic/elasticsearch/issues/119632 +- class: org.elasticsearch.search.profile.dfs.DfsProfilerIT + method: testProfileDfs + issue: https://github.com/elastic/elasticsearch/issues/119711 # Examples: # diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index ed8a1f147b4f..a104ec675adc 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -60,5 +60,23 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.skipTest("cat.aliases/10_basic/Deprecated local parameter", "CAT APIs not covered by compatibility policy") task.skipTest("cat.shards/10_basic/Help", "sync_id is removed in 9.0") task.skipTest("search/500_date_range/from, to, include_lower, include_upper deprecated", "deprecated parameters are removed in 9.0") + task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/180_update_dense_vector_type/Test create and update dense vector mapping with bulk indexing", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN query in a bool clause - missing num_candidates", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/Simple knn query", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN search used in nested field - missing num_candidates", "waiting for #118774 backport") + task.skipTest("search.vectors/180_update_dense_vector_type/Test create and update dense vector mapping to int4 with per-doc indexing and flush", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/PRE_FILTER: knn query with internal filter as pre-filter", "waiting for #118774 backport") + task.skipTest("search.vectors/180_update_dense_vector_type/Index, update and merge", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN query with missing num_candidates param - size provided", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/POST_FILTER: knn query with filter from a parent bool query as post-filter", "waiting for #118774 backport") + task.skipTest("search.vectors/120_knn_query_multiple_shards/Aggregations with collected number of docs depends on num_candidates", "waiting for #118774 backport") + task.skipTest("search.vectors/180_update_dense_vector_type/Test create and update dense vector mapping with per-doc indexing and flush", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/PRE_FILTER: knn query with alias filter as pre-filter", "waiting for #118774 backport") + task.skipTest("search.vectors/140_knn_query_with_other_queries/Function score query with knn query", "waiting for #118774 backport") + task.skipTest("search.vectors/130_knn_query_nested_search/nested kNN search inner_hits size > 1", "waiting for #118774 backport") + task.skipTest("search.vectors/110_knn_query_with_filter/PRE_FILTER: pre-filter across multiple aliases", "waiting for #118774 backport") + task.skipTest("search.vectors/160_knn_query_missing_params/kNN search in a dis_max query - missing num_candidates", "waiting for #118774 backport") task.skipTest("search.highlight/30_max_analyzed_offset/Plain highlighter with max_analyzed_offset < 0 should FAIL", "semantics of test has changed") }) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml index 618951711cff..3d4841a16d82 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/110_knn_query_with_filter.yml @@ -59,7 +59,9 @@ setup: --- "Simple knn query": - + - requires: + cluster_features: "search.vectors.k_param_supported" + reason: 'k param for knn as query is required' - do: search: index: my_index @@ -71,8 +73,9 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - - match: { hits.total.value: 5 } # collector sees num_candidates docs + - match: { hits.total.value: 5 } - length: {hits.hits: 3} - match: { hits.hits.0._id: "1" } - match: { hits.hits.0.fields.my_name.0: v1 } @@ -93,8 +96,9 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - - match: { hits.total.value: 5 } # collector sees num_candidates docs + - match: { hits.total.value: 5 } - length: {hits.hits: 3} - match: { hits.hits.0._id: "2" } - match: { hits.hits.0.fields.my_name.0: v2 } @@ -140,6 +144,7 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - match: { hits.total.value: 5 } - length: { hits.hits: 3 } @@ -184,6 +189,7 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 100 + k: 100 - match: { hits.total.value: 10 } # 5 docs from each alias - length: {hits.hits: 6} @@ -213,6 +219,7 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 filter: term: my_name: v2 @@ -243,9 +250,10 @@ setup: field: my_vector query_vector: [1, 1, 1, 1] num_candidates: 5 + k: 5 - match: { hits.total.value: 2 } - - length: {hits.hits: 2} # knn query returns top 5 docs, but they are post-filtered to 2 docs + - length: {hits.hits: 2} # knn query returns top 3 docs, but they are post-filtered to 2 docs - match: { hits.hits.0._id: "2" } - match: { hits.hits.0.fields.my_name.0: v2 } - match: { hits.hits.1._id: "4" } @@ -271,4 +279,4 @@ setup: my_name: v1 - match: { hits.total.value: 0} - - length: { hits.hits: 0 } # knn query returns top 5 docs, but they are post-filtered to 0 docs + - length: { hits.hits: 0 } # knn query returns top 3 docs, but they are post-filtered to 0 docs diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml index c6f3e187f795..c68565e6629f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/120_knn_query_multiple_shards.yml @@ -166,55 +166,3 @@ setup: - close_to: { hits.hits.2._score: { value: 120, error: 0.00001 } } - close_to: { hits.hits.2.matched_queries.bm25_query: { value: 100.0, error: 0.00001 } } - close_to: { hits.hits.2.matched_queries.knn_query: { value: 20.0, error: 0.00001 } } - ---- -"Aggregations with collected number of docs depends on num_candidates": - - do: - search: - index: my_index - body: - size: 2 - query: - knn: - field: my_vector - query_vector: [1, 1, 1, 1] - num_candidates: 100 # collect up to 100 candidates from each shard - aggs: - my_agg: - terms: - field: my_name - order: - _key: asc - - - length: {hits.hits: 2} - - match: {hits.total.value: 12} - - match: {aggregations.my_agg.buckets.0.key: 'v1'} - - match: {aggregations.my_agg.buckets.1.key: 'v2'} - - match: {aggregations.my_agg.buckets.0.doc_count: 6} - - match: {aggregations.my_agg.buckets.1.doc_count: 6} - - - do: - search: - index: my_index - body: - size: 2 - query: - knn: - field: my_vector - query_vector: [ 1, 1, 1, 1 ] - num_candidates: 3 # collect 3 candidates from each shard - aggs: - my_agg2: - terms: - field: my_name - order: - _key: asc - my_sum_buckets: - sum_bucket: - buckets_path: "my_agg2>_count" - - - length: { hits.hits: 2 } - - match: { hits.total.value: 6 } - - match: { aggregations.my_agg2.buckets.0.key: 'v1' } - - match: { aggregations.my_agg2.buckets.1.key: 'v2' } - - match: { aggregations.my_sum_buckets.value: 6.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml index 79ff3f61742f..bf0714497565 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml @@ -273,6 +273,7 @@ setup: knn: field: nested.vector query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 5 num_candidates: 5 inner_hits: { size: 2, "fields": [ "nested.paragraph_id" ], _source: false } @@ -295,6 +296,7 @@ setup: knn: field: nested.vector query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 5 num_candidates: 5 inner_hits: { size: 2, "fields": [ "nested.paragraph_id" ], _source: false } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml index d52a5daf2234..1e54e497f286 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml @@ -69,6 +69,7 @@ setup: field: my_vector query_vector: [ 1, 1, 1, 1 ] num_candidates: 5 + k: 5 functions: - filter: { match: { my_name: v1 } } weight: 10 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml index 02962e049e26..26c52060dfb2 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/160_knn_query_missing_params.yml @@ -100,8 +100,9 @@ setup: knn: field: vector query_vector: [1, 1, 1] + k: 2 size: 1 - - match: { hits.total: 2 } # due to num_candidates defined as round(1.5 * size), so we only see 2 results + - match: { hits.total: 2 } # k defaults to size - length: { hits.hits: 1 } # one result is only returned though --- @@ -117,6 +118,7 @@ setup: field: vector query_vector: [-1, -1, -1] num_candidates: 1 + k: 1 size: 10 - match: { hits.total: 1 } @@ -137,9 +139,10 @@ setup: - knn: field: vector query_vector: [ 1, 1, 0] + k: 1 size: 1 - - match: { hits.total: 2 } # due to num_candidates defined as round(1.5 * size), so we only see 2 results from cat:A + - match: { hits.total: 1 } - length: { hits.hits: 1 } --- @@ -154,6 +157,7 @@ setup: - knn: field: vector query_vector: [1, 1, 0] + k: 2 - match: category: B tie_breaker: 0.8 @@ -175,6 +179,7 @@ setup: knn: field: nested.vector query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 inner_hits: { size: 1, "fields": [ "nested.paragraph_id" ], _source: false } size: 1 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml index 855daeaa7f16..99943ef2671b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/180_update_dense_vector_type.yml @@ -109,6 +109,7 @@ setup: field: embedding query_vector: [1, 1, 1, 1] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: {hits.hits: 3} @@ -215,6 +216,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -322,6 +324,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -430,6 +433,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 40 + k: 40 - match: { hits.total.value: 40 } - length: { hits.hits: 5 } @@ -499,6 +503,7 @@ setup: field: embedding query_vector: [1, 1, 1, 1] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: {hits.hits: 3} @@ -559,6 +564,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -620,6 +626,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -682,6 +689,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 40 + k: 40 - match: { hits.total.value: 40 } - length: { hits.hits: 5 } @@ -751,6 +759,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: { hits.hits: 3 } @@ -791,6 +800,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: { hits.hits: 3 } @@ -833,6 +843,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -869,6 +880,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -911,6 +923,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -933,6 +946,7 @@ setup: knn: field: embedding query_vector: [ 1, 1, 1, 1 ] + k: 30 num_candidates: 30 - match: { hits.total.value: 30 } @@ -1769,6 +1783,7 @@ setup: field: embedding query_vector: [1, 1, 1, 1] num_candidates: 10 + k: 10 - match: { hits.total.value: 10 } - length: {hits.hits: 3} @@ -1875,6 +1890,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 20 + k: 20 - match: { hits.total.value: 20 } - length: { hits.hits: 3 } @@ -1982,6 +1998,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 30 + k: 30 - match: { hits.total.value: 30 } - length: { hits.hits: 4 } @@ -2090,6 +2107,7 @@ setup: field: embedding query_vector: [ 1, 1, 1, 1 ] num_candidates: 40 + k: 40 - match: { hits.total.value: 40 } - length: { hits.hits: 5 } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 0fbface3793a..8568b6091676 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -271,7 +271,7 @@ final class FetchSearchPhase extends SearchPhase { ) { context.executeNextPhase(this, () -> { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); - context.addReleasable(resp::decRef); + context.addReleasable(resp); return nextPhaseFactory.apply(resp, searchPhaseShardResults); }); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchResponseSections.java b/server/src/main/java/org/elasticsearch/action/search/SearchResponseSections.java index 8c9a42a61e33..9d85348b80d6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchResponseSections.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchResponseSections.java @@ -9,14 +9,12 @@ package org.elasticsearch.action.search; -import org.elasticsearch.core.RefCounted; -import org.elasticsearch.core.SimpleRefCounted; +import org.elasticsearch.core.Releasable; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.profile.SearchProfileResults; import org.elasticsearch.search.profile.SearchProfileShardResult; import org.elasticsearch.search.suggest.Suggest; -import org.elasticsearch.transport.LeakTracker; import java.util.Collections; import java.util.Map; @@ -25,7 +23,7 @@ import java.util.Map; * Holds some sections that a search response is composed of (hits, aggs, suggestions etc.) during some steps of the search response * building. */ -public class SearchResponseSections implements RefCounted { +public class SearchResponseSections implements Releasable { public static final SearchResponseSections EMPTY_WITH_TOTAL_HITS = new SearchResponseSections( SearchHits.EMPTY_WITH_TOTAL_HITS, @@ -53,8 +51,6 @@ public class SearchResponseSections implements RefCounted { protected final Boolean terminatedEarly; protected final int numReducePhases; - private final RefCounted refCounted; - public SearchResponseSections( SearchHits hits, InternalAggregations aggregations, @@ -72,7 +68,6 @@ public class SearchResponseSections implements RefCounted { this.timedOut = timedOut; this.terminatedEarly = terminatedEarly; this.numReducePhases = numReducePhases; - refCounted = hits.getHits().length > 0 ? LeakTracker.wrap(new SimpleRefCounted()) : ALWAYS_REFERENCED; } public final SearchHits hits() { @@ -97,26 +92,7 @@ public class SearchResponseSections implements RefCounted { } @Override - public void incRef() { - refCounted.incRef(); - } - - @Override - public boolean tryIncRef() { - return refCounted.tryIncRef(); - } - - @Override - public boolean decRef() { - if (refCounted.decRef()) { - hits.decRef(); - return true; - } - return false; - } - - @Override - public boolean hasReferences() { - return refCounted.hasReferences(); + public void close() { + hits.decRef(); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java index 60e96a8cce8a..2231f791384f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java @@ -246,8 +246,7 @@ abstract class SearchScrollAsyncAction implements R if (request.scroll() != null) { scrollId = request.scrollId(); } - var sections = SearchPhaseController.merge(true, queryPhase, fetchResults); - try { + try (var sections = SearchPhaseController.merge(true, queryPhase, fetchResults)) { ActionListener.respondAndRelease( listener, new SearchResponse( @@ -262,8 +261,6 @@ abstract class SearchScrollAsyncAction implements R null ) ); - } finally { - sections.decRef(); } } catch (Exception e) { listener.onFailure(new ReduceSearchPhaseException("fetch", "inner finish failed", e, buildShardFailures())); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index a99f21803556..b2b23baacc4d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2019,7 +2019,7 @@ public class DenseVectorFieldMapper extends FieldMapper { public Query createKnnQuery( VectorData queryVector, - Integer k, + int k, int numCands, Float numCandsFactor, Query filter, @@ -2052,7 +2052,7 @@ public class DenseVectorFieldMapper extends FieldMapper { private Query createKnnBitQuery( byte[] queryVector, - Integer k, + int k, int numCands, Query filter, Float similarityThreshold, @@ -2074,7 +2074,7 @@ public class DenseVectorFieldMapper extends FieldMapper { private Query createKnnByteQuery( byte[] queryVector, - Integer k, + int k, int numCands, Query filter, Float similarityThreshold, @@ -2101,7 +2101,7 @@ public class DenseVectorFieldMapper extends FieldMapper { private Query createKnnFloatQuery( float[] queryVector, - Integer k, + int k, int numCands, Float numCandsFactor, Query filter, diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 1548a62d2b3e..c34314149079 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -37,6 +37,8 @@ public final class SearchCapabilities { private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support"; /** Fixed the math in {@code moving_fn}'s {@code linearWeightedAvg}. */ private static final String MOVING_FN_RIGHT_MATH = "moving_fn_right_math"; + /** knn query where k defaults to the request size. */ + private static final String K_DEFAULT_TO_SIZE = "k_default_to_size"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; @@ -57,6 +59,7 @@ public final class SearchCapabilities { capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE); capabilities.add(MOVING_FN_RIGHT_MATH); + capabilities.add(K_DEFAULT_TO_SIZE); if (Build.current().isSnapshot()) { capabilities.add(KQL_QUERY_SUPPORTED); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index b18ce2dff65c..9b9718efcf52 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -465,7 +465,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, rescoreVectorBuilder, similarity).boost(boost) + return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index 81b00f132959..12573d5ad496 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -256,7 +256,7 @@ public class KnnSearchRequestParser { if (numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, null); + return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, null, null); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 88f6312fa7e6..a65757cc2587 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -495,15 +495,16 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder mSearchResponses = new ArrayList<>(numInnerHits); for (int innerHitNum = 0; innerHitNum < numInnerHits; innerHitNum++) { - var sections = new SearchResponseSections(collapsedHits.get(innerHitNum), null, null, false, null, null, 1); - try { + try ( + var sections = new SearchResponseSections(collapsedHits.get(innerHitNum), null, null, false, null, null, 1) + ) { mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); } mSearchResponses.add(new MultiSearchResponse.Item(mockSearchPhaseContext.searchResponse.get(), null)); // transferring ownership to the multi-search response so no need to release here @@ -121,11 +120,8 @@ public class ExpandSearchPhaseTests extends ESTestCase { ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { @Override public void run() { - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { + try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); } } }); @@ -215,11 +211,8 @@ public class ExpandSearchPhaseTests extends ESTestCase { ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { @Override public void run() { - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { + try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); } } }); @@ -254,11 +247,8 @@ public class ExpandSearchPhaseTests extends ESTestCase { ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { @Override public void run() { - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { + try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { mockSearchPhaseContext.sendSearchResponse(sections, null); - } finally { - sections.decRef(); } } }); diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchLookupFieldsPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchLookupFieldsPhaseTests.java index 1d2daf0cd660..5c508dce61fc 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchLookupFieldsPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchLookupFieldsPhaseTests.java @@ -47,12 +47,10 @@ public class FetchLookupFieldsPhaseTests extends ESTestCase { searchHits[i] = SearchHitTests.createTestItem(randomBoolean(), randomBoolean()); } SearchHits hits = new SearchHits(searchHits, new TotalHits(numHits, TotalHits.Relation.EQUAL_TO), 1.0f); - var sections = new SearchResponseSections(hits, null, null, false, null, null, 1); - try { + try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { FetchLookupFieldsPhase phase = new FetchLookupFieldsPhase(searchPhaseContext, sections, null); phase.run(); } finally { - sections.decRef(); hits.decRef(); } searchPhaseContext.assertNoFailure(); @@ -189,12 +187,10 @@ public class FetchLookupFieldsPhaseTests extends ESTestCase { new TotalHits(2, TotalHits.Relation.EQUAL_TO), 1.0f ); - var sections = new SearchResponseSections(searchHits, null, null, false, null, null, 1); - try { + try (var sections = new SearchResponseSections(searchHits, null, null, false, null, null, 1)) { FetchLookupFieldsPhase phase = new FetchLookupFieldsPhase(searchPhaseContext, sections, null); phase.run(); } finally { - sections.decRef(); searchHits.decRef(); } assertTrue(requestSent.get()); diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 353188af8be3..568186e0cae5 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -418,7 +418,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase { float[] queryVector = randomVector(); assertResponse( client().prepareSearch("index1", "index2") - .setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null)) + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, 5, null, null)) .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 9a507977c012..bf8148608736 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -292,8 +292,7 @@ public class SearchPhaseControllerTests extends ESTestCase { reducedQueryPhase.suggest(), profile ); - final SearchResponseSections mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults); - try { + try (SearchResponseSections mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults)) { if (trackTotalHits == SearchContext.TRACK_TOTAL_HITS_DISABLED) { assertNull(mergedResponse.hits.getTotalHits()); } else { @@ -346,7 +345,6 @@ public class SearchPhaseControllerTests extends ESTestCase { assertThat(mergedResponse.profile(), is(anEmptyMap())); } } finally { - mergedResponse.decRef(); fetchResults.asList().forEach(TransportMessage::decRef); } } finally { @@ -410,8 +408,7 @@ public class SearchPhaseControllerTests extends ESTestCase { reducedQueryPhase.suggest(), false ); - SearchResponseSections mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults); - try { + try (SearchResponseSections mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults)) { if (trackTotalHits == SearchContext.TRACK_TOTAL_HITS_DISABLED) { assertNull(mergedResponse.hits.getTotalHits()); } else { @@ -427,7 +424,6 @@ public class SearchPhaseControllerTests extends ESTestCase { assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.sortedTopDocs().scoreDocs().length)); assertThat(mergedResponse.profile(), is(anEmptyMap())); } finally { - mergedResponse.decRef(); fetchResults.asList().forEach(TransportMessage::decRef); } } finally { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 375712ee6086..244d53940331 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -47,7 +47,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT; -import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -82,7 +81,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, - Integer k, + int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity @@ -138,8 +137,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa @Override protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; - Integer k = randomBoolean() ? null : randomIntBetween(1, 100); - int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000); + int k = randomIntBetween(1, 100); + int numCands = randomIntBetween(k + 20, 1000); KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( fieldName, k, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index f6c2e754cec6..26066389c63f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -20,7 +20,7 @@ public class KnnByteVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilde @Override protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, - Integer k, + int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index 6f67e4be29a0..70d29ab525ef 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -20,7 +20,7 @@ public class KnnFloatVectorQueryBuilderTests extends AbstractKnnVectorQueryBuild @Override KnnVectorQueryBuilder createKnnVectorQueryBuilder( String fieldName, - Integer k, + int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index a39438af5b72..108dc60e2ee3 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -224,9 +224,9 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase result = runEsql(builder); assertMap( result, @@ -70,7 +76,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { ); // filter includes only test1. Columns from test2 are filtered out, as well (not only rows)! - builder = timestampFilter("gte", "2024-01-01").query("FROM test*"); + builder = timestampFilter("gte", "2024-01-01").query(from("test*")); assertMap( runEsql(builder), matchesMap().entry( @@ -83,7 +89,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { // filter excludes both indices (no rows); the first analysis step fails because there are no columns, a second attempt succeeds // after eliminating the index filter. All columns are returned. - builder = timestampFilter("gte", "2025-01-01").query("FROM test*"); + builder = timestampFilter("gte", "2025-01-01").query(from("test*")); assertMap( runEsql(builder), matchesMap().entry( @@ -103,7 +109,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { indexTimestampData(docsTest2, "test2", "2023-11-26", "id2"); // filter includes only test1. Columns and rows of test2 are filtered out - RestEsqlTestCase.RequestObjectBuilder builder = existsFilter("id1").query("FROM test*"); + RestEsqlTestCase.RequestObjectBuilder builder = existsFilter("id1").query(from("test*")); Map result = runEsql(builder); assertMap( result, @@ -116,7 +122,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { ); // filter includes only test1. Columns from test2 are filtered out, as well (not only rows)! - builder = existsFilter("id1").query("FROM test* METADATA _index | KEEP _index, id*"); + builder = existsFilter("id1").query(from("test*") + " METADATA _index | KEEP _index, id*"); result = runEsql(builder); assertMap( result, @@ -129,7 +135,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { @SuppressWarnings("unchecked") var values = (List>) result.get("values"); for (List row : values) { - assertThat(row.get(0), equalTo("test1")); + assertThat(row.get(0), oneOf("test1", "remote_cluster:test1")); assertThat(row.get(1), instanceOf(Integer.class)); } } @@ -142,7 +148,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { // test2 is explicitly used in a query with "SORT id2" even if the index filter should discard test2 RestEsqlTestCase.RequestObjectBuilder builder = existsFilter("id1").query( - "FROM test* METADATA _index | SORT id2 | KEEP _index, id*" + from("test*") + " METADATA _index | SORT id2 | KEEP _index, id*" ); Map result = runEsql(builder); assertMap( @@ -157,7 +163,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { @SuppressWarnings("unchecked") var values = (List>) result.get("values"); for (List row : values) { - assertThat(row.get(0), equalTo("test1")); + assertThat(row.get(0), oneOf("test1", "remote_cluster:test1")); assertThat(row.get(1), instanceOf(Integer.class)); assertThat(row.get(2), nullValue()); } @@ -172,59 +178,59 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { // idx field name is explicitly used, though it doesn't exist in any of the indices. First test - without filter ResponseException e = expectThrows( ResponseException.class, - () -> runEsql(requestObjectBuilder().query("FROM test* | WHERE idx == 123")) + () -> runEsql(requestObjectBuilder().query(from("test*") + " | WHERE idx == 123")) ); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("verification_exception")); assertThat(e.getMessage(), containsString("Found 1 problem")); - assertThat(e.getMessage(), containsString("line 1:20: Unknown column [idx]")); + assertThat(e.getMessage(), containsString("Unknown column [idx]")); - e = expectThrows(ResponseException.class, () -> runEsql(requestObjectBuilder().query("FROM test1 | WHERE idx == 123"))); + e = expectThrows(ResponseException.class, () -> runEsql(requestObjectBuilder().query(from("test1") + " | WHERE idx == 123"))); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("verification_exception")); assertThat(e.getMessage(), containsString("Found 1 problem")); - assertThat(e.getMessage(), containsString("line 1:20: Unknown column [idx]")); + assertThat(e.getMessage(), containsString("Unknown column [idx]")); e = expectThrows( ResponseException.class, - () -> runEsql(timestampFilter("gte", "2020-01-01").query("FROM test* | WHERE idx == 123")) + () -> runEsql(timestampFilter("gte", "2020-01-01").query(from("test*") + " | WHERE idx == 123")) ); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("Found 1 problem")); - assertThat(e.getMessage(), containsString("line 1:20: Unknown column [idx]")); + assertThat(e.getMessage(), containsString("Unknown column [idx]")); e = expectThrows( ResponseException.class, - () -> runEsql(timestampFilter("gte", "2020-01-01").query("FROM test2 | WHERE idx == 123")) + () -> runEsql(timestampFilter("gte", "2020-01-01").query(from("test2") + " | WHERE idx == 123")) ); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("Found 1 problem")); - assertThat(e.getMessage(), containsString("line 1:20: Unknown column [idx]")); + assertThat(e.getMessage(), containsString("Unknown column [idx]")); } public void testIndicesDontExist() throws IOException { int docsTest1 = 0; // we are interested only in the created index, not necessarily that it has data indexTimestampData(docsTest1, "test1", "2024-11-26", "id1"); - ResponseException e = expectThrows(ResponseException.class, () -> runEsql(timestampFilter("gte", "2020-01-01").query("FROM foo"))); + ResponseException e = expectThrows(ResponseException.class, () -> runEsql(timestampFilter("gte", "2020-01-01").query(from("foo")))); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("verification_exception")); - assertThat(e.getMessage(), containsString("Unknown index [foo]")); + assertThat(e.getMessage(), anyOf(containsString("Unknown index [foo]"), containsString("Unknown index [remote_cluster:foo]"))); - e = expectThrows(ResponseException.class, () -> runEsql(timestampFilter("gte", "2020-01-01").query("FROM foo*"))); + e = expectThrows(ResponseException.class, () -> runEsql(timestampFilter("gte", "2020-01-01").query(from("foo*")))); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("verification_exception")); - assertThat(e.getMessage(), containsString("Unknown index [foo*]")); + assertThat(e.getMessage(), anyOf(containsString("Unknown index [foo*]"), containsString("Unknown index [remote_cluster:foo*]"))); - e = expectThrows(ResponseException.class, () -> runEsql(timestampFilter("gte", "2020-01-01").query("FROM foo,test1"))); + e = expectThrows(ResponseException.class, () -> runEsql(timestampFilter("gte", "2020-01-01").query(from("foo", "test1")))); assertEquals(404, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("index_not_found_exception")); - assertThat(e.getMessage(), containsString("no such index [foo]")); + assertThat(e.getMessage(), anyOf(containsString("no such index [foo]"), containsString("no such index [remote_cluster:foo]"))); if (EsqlCapabilities.Cap.JOIN_LOOKUP_V10.isEnabled()) { e = expectThrows( ResponseException.class, - () -> runEsql(timestampFilter("gte", "2020-01-01").query("FROM test1 | LOOKUP JOIN foo ON id1")) + () -> runEsql(timestampFilter("gte", "2020-01-01").query(from("test1") + " | LOOKUP JOIN foo ON id1")) ); assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); assertThat(e.getMessage(), containsString("verification_exception")); @@ -251,6 +257,11 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { } protected void indexTimestampData(int docs, String indexName, String date, String differentiatorFieldName) throws IOException { + indexTimestampDataForClient(client(), docs, indexName, date, differentiatorFieldName); + } + + protected void indexTimestampDataForClient(RestClient client, int docs, String indexName, String date, String differentiatorFieldName) + throws IOException { Request createIndex = new Request("PUT", indexName); createIndex.setJsonEntity(""" { @@ -273,7 +284,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { } } }""".replace("%differentiator_field_name%", differentiatorFieldName)); - Response response = client().performRequest(createIndex); + Response response = client.performRequest(createIndex); assertThat( entityToMap(response.getEntity(), XContentType.JSON), matchesMap().entry("shards_acknowledged", true).entry("index", indexName).entry("acknowledged", true) @@ -291,7 +302,7 @@ public abstract class RequestIndexFilteringTestCase extends ESRestTestCase { bulk.addParameter("refresh", "true"); bulk.addParameter("filter_path", "errors"); bulk.setJsonEntity(b.toString()); - response = client().performRequest(bulk); + response = client.performRequest(bulk); Assert.assertEquals("{\"errors\":false}", EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8)); } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java index 5333f1b905b2..9fd9c4c294c9 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java @@ -19,6 +19,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.DataStreamAction; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.persistent.AllocatedPersistentTask; @@ -61,6 +62,7 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec ) { ReindexDataStreamTaskParams params = taskInProgress.getParams(); return new ReindexDataStreamTask( + clusterService, params.startTime(), params.totalIndices(), params.totalIndicesToBeUpgraded(), @@ -74,7 +76,12 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec } @Override - protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTaskParams params, PersistentTaskState state) { + protected void nodeOperation( + AllocatedPersistentTask task, + ReindexDataStreamTaskParams params, + PersistentTaskState persistentTaskState + ) { + ReindexDataStreamPersistentTaskState state = (ReindexDataStreamPersistentTaskState) persistentTaskState; String sourceDataStream = params.getSourceDataStream(); TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId()); GetDataStreamAction.Request request = new GetDataStreamAction.Request(TimeValue.MAX_VALUE, new String[] { sourceDataStream }); @@ -93,22 +100,43 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec RolloverAction.INSTANCE, rolloverRequest, ActionListener.wrap( - rolloverResponse -> reindexIndices(dataStream, reindexDataStreamTask, reindexClient, sourceDataStream, taskId), - e -> completeFailedPersistentTask(reindexDataStreamTask, e) + rolloverResponse -> reindexIndices( + dataStream, + dataStream.getIndices().size() + 1, + reindexDataStreamTask, + params, + state, + reindexClient, + sourceDataStream, + taskId + ), + e -> completeFailedPersistentTask(reindexDataStreamTask, state, e) ) ); } else { - reindexIndices(dataStream, reindexDataStreamTask, reindexClient, sourceDataStream, taskId); + reindexIndices( + dataStream, + dataStream.getIndices().size(), + reindexDataStreamTask, + params, + state, + reindexClient, + sourceDataStream, + taskId + ); } } else { - completeFailedPersistentTask(reindexDataStreamTask, new ElasticsearchException("data stream does not exist")); + completeFailedPersistentTask(reindexDataStreamTask, state, new ElasticsearchException("data stream does not exist")); } - }, exception -> completeFailedPersistentTask(reindexDataStreamTask, exception))); + }, exception -> completeFailedPersistentTask(reindexDataStreamTask, state, exception))); } private void reindexIndices( DataStream dataStream, + int totalIndicesInDataStream, ReindexDataStreamTask reindexDataStreamTask, + ReindexDataStreamTaskParams params, + ReindexDataStreamPersistentTaskState state, ExecuteWithHeadersClient reindexClient, String sourceDataStream, TaskId parentTaskId @@ -117,11 +145,28 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec List indicesToBeReindexed = indices.stream() .filter(getReindexRequiredPredicate(clusterService.state().metadata().getProject())) .toList(); + final ReindexDataStreamPersistentTaskState updatedState; + if (params.totalIndices() != totalIndicesInDataStream + || params.totalIndicesToBeUpgraded() != indicesToBeReindexed.size() + || (state != null + && (state.totalIndices() != null + && state.totalIndicesToBeUpgraded() != null + && (state.totalIndices() != totalIndicesInDataStream + || state.totalIndicesToBeUpgraded() != indicesToBeReindexed.size())))) { + updatedState = new ReindexDataStreamPersistentTaskState( + totalIndicesInDataStream, + indicesToBeReindexed.size(), + state == null ? null : state.completionTime() + ); + reindexDataStreamTask.updatePersistentTaskState(updatedState, ActionListener.noop()); + } else { + updatedState = state; + } reindexDataStreamTask.setPendingIndicesCount(indicesToBeReindexed.size()); // The CountDownActionListener is 1 more than the number of indices so that the count is not 0 if we have no indices CountDownActionListener listener = new CountDownActionListener(indicesToBeReindexed.size() + 1, ActionListener.wrap(response1 -> { - completeSuccessfulPersistentTask(reindexDataStreamTask); - }, exception -> { completeFailedPersistentTask(reindexDataStreamTask, exception); })); + completeSuccessfulPersistentTask(reindexDataStreamTask, updatedState); + }, exception -> { completeFailedPersistentTask(reindexDataStreamTask, updatedState, exception); })); List indicesRemaining = Collections.synchronizedList(new ArrayList<>(indicesToBeReindexed)); final int maxConcurrentIndices = 1; for (int i = 0; i < maxConcurrentIndices; i++) { @@ -193,15 +238,25 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec }); } - private void completeSuccessfulPersistentTask(ReindexDataStreamTask persistentTask) { - persistentTask.allReindexesCompleted(threadPool, getTimeToLive(persistentTask)); + private void completeSuccessfulPersistentTask( + ReindexDataStreamTask persistentTask, + @Nullable ReindexDataStreamPersistentTaskState state + ) { + persistentTask.allReindexesCompleted(threadPool, updateCompletionTimeAndGetTimeToLive(persistentTask, state)); } - private void completeFailedPersistentTask(ReindexDataStreamTask persistentTask, Exception e) { - persistentTask.taskFailed(threadPool, getTimeToLive(persistentTask), e); + private void completeFailedPersistentTask( + ReindexDataStreamTask persistentTask, + @Nullable ReindexDataStreamPersistentTaskState state, + Exception e + ) { + persistentTask.taskFailed(threadPool, updateCompletionTimeAndGetTimeToLive(persistentTask, state), e); } - private TimeValue getTimeToLive(ReindexDataStreamTask reindexDataStreamTask) { + private TimeValue updateCompletionTimeAndGetTimeToLive( + ReindexDataStreamTask reindexDataStreamTask, + @Nullable ReindexDataStreamPersistentTaskState state + ) { PersistentTasksCustomMetadata persistentTasksCustomMetadata = clusterService.state() .getMetadata() .getProject() @@ -212,16 +267,23 @@ public class ReindexDataStreamPersistentTaskExecutor extends PersistentTasksExec if (persistentTask == null) { return TimeValue.timeValueMillis(0); } - PersistentTaskState state = persistentTask.getState(); final long completionTime; if (state == null) { completionTime = threadPool.absoluteTimeInMillis(); reindexDataStreamTask.updatePersistentTaskState( - new ReindexDataStreamPersistentTaskState(completionTime), + new ReindexDataStreamPersistentTaskState(null, null, completionTime), ActionListener.noop() ); } else { - completionTime = ((ReindexDataStreamPersistentTaskState) state).completionTime(); + if (state.completionTime() == null) { + completionTime = threadPool.absoluteTimeInMillis(); + reindexDataStreamTask.updatePersistentTaskState( + new ReindexDataStreamPersistentTaskState(state.totalIndices(), state.totalIndicesToBeUpgraded(), completionTime), + ActionListener.noop() + ); + } else { + completionTime = state.completionTime(); + } } return TimeValue.timeValueMillis(TASK_KEEP_ALIVE_TIME.millis() - (threadPool.absoluteTimeInMillis() - completionTime)); } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskState.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskState.java index 130a8f7ce372..8ab22674082e 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskState.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskState.java @@ -9,6 +9,7 @@ package org.elasticsearch.xpack.migrate.task; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.ConstructingObjectParser; @@ -18,22 +19,31 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public record ReindexDataStreamPersistentTaskState( + @Nullable Integer totalIndices, + @Nullable Integer totalIndicesToBeUpgraded, + @Nullable Long completionTime +) implements Task.Status, PersistentTaskState { -public record ReindexDataStreamPersistentTaskState(long completionTime) implements Task.Status, PersistentTaskState { public static final String NAME = ReindexDataStreamTask.TASK_NAME; + private static final String TOTAL_INDICES_FIELD = "total_indices_in_data_stream"; + private static final String TOTAL_INDICES_REQUIRING_UPGRADE_FIELD = "total_indices_requiring_upgrade"; private static final String COMPLETION_TIME_FIELD = "completion_time"; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, true, - args -> new ReindexDataStreamPersistentTaskState((long) args[0]) + args -> new ReindexDataStreamPersistentTaskState((Integer) args[0], (Integer) args[1], (Long) args[2]) ); static { - PARSER.declareLong(constructorArg(), new ParseField(COMPLETION_TIME_FIELD)); + PARSER.declareInt(optionalConstructorArg(), new ParseField(TOTAL_INDICES_FIELD)); + PARSER.declareInt(optionalConstructorArg(), new ParseField(TOTAL_INDICES_REQUIRING_UPGRADE_FIELD)); + PARSER.declareLong(optionalConstructorArg(), new ParseField(COMPLETION_TIME_FIELD)); } public ReindexDataStreamPersistentTaskState(StreamInput in) throws IOException { - this(in.readLong()); + this(in.readOptionalInt(), in.readOptionalInt(), in.readOptionalLong()); } @Override @@ -43,13 +53,23 @@ public record ReindexDataStreamPersistentTaskState(long completionTime) implemen @Override public void writeTo(StreamOutput out) throws IOException { - out.writeLong(completionTime); + out.writeOptionalInt(totalIndices); + out.writeOptionalInt(totalIndicesToBeUpgraded); + out.writeOptionalLong(completionTime); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(COMPLETION_TIME_FIELD, completionTime); + if (totalIndices != null) { + builder.field(TOTAL_INDICES_FIELD, totalIndices); + } + if (totalIndicesToBeUpgraded != null) { + builder.field(TOTAL_INDICES_REQUIRING_UPGRADE_FIELD, totalIndicesToBeUpgraded); + } + if (completionTime != null) { + builder.field(COMPLETION_TIME_FIELD, completionTime); + } builder.endObject(); return builder; } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java index 7a2b759dfd17..8f98a8d4d1da 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.migrate.task; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; @@ -24,9 +26,10 @@ import java.util.concurrent.atomic.AtomicInteger; public class ReindexDataStreamTask extends AllocatedPersistentTask { public static final String TASK_NAME = "reindex-data-stream"; + private final ClusterService clusterService; private final long persistentTaskStartTime; - private final int totalIndices; - private final int totalIndicesToBeUpgraded; + private final int initialTotalIndices; + private final int initialTotalIndicesToBeUpgraded; private volatile boolean complete = false; private volatile Exception exception; private final Set inProgress = Collections.synchronizedSet(new HashSet<>()); @@ -36,9 +39,10 @@ public class ReindexDataStreamTask extends AllocatedPersistentTask { @SuppressWarnings("this-escape") public ReindexDataStreamTask( + ClusterService clusterService, long persistentTaskStartTime, - int totalIndices, - int totalIndicesToBeUpgraded, + int initialTotalIndices, + int initialTotalIndicesToBeUpgraded, long id, String type, String action, @@ -47,9 +51,10 @@ public class ReindexDataStreamTask extends AllocatedPersistentTask { Map headers ) { super(id, type, action, description, parentTask, headers); + this.clusterService = clusterService; this.persistentTaskStartTime = persistentTaskStartTime; - this.totalIndices = totalIndices; - this.totalIndicesToBeUpgraded = totalIndicesToBeUpgraded; + this.initialTotalIndices = initialTotalIndices; + this.initialTotalIndicesToBeUpgraded = initialTotalIndicesToBeUpgraded; this.completeTask = new RunOnce(() -> { if (exception == null) { markAsCompleted(); @@ -61,6 +66,20 @@ public class ReindexDataStreamTask extends AllocatedPersistentTask { @Override public ReindexDataStreamStatus getStatus() { + PersistentTasksCustomMetadata persistentTasksCustomMetadata = clusterService.state() + .getMetadata() + .getProject() + .custom(PersistentTasksCustomMetadata.TYPE); + int totalIndices = initialTotalIndices; + int totalIndicesToBeUpgraded = initialTotalIndicesToBeUpgraded; + PersistentTasksCustomMetadata.PersistentTask persistentTask = persistentTasksCustomMetadata.getTask(getPersistentTaskId()); + if (persistentTask != null) { + ReindexDataStreamPersistentTaskState state = (ReindexDataStreamPersistentTaskState) persistentTask.getState(); + if (state != null && state.totalIndices() != null && state.totalIndicesToBeUpgraded() != null) { + totalIndices = Math.toIntExact(state.totalIndices()); + totalIndicesToBeUpgraded = Math.toIntExact(state.totalIndicesToBeUpgraded()); + } + } return new ReindexDataStreamStatus( persistentTaskStartTime, totalIndices, diff --git a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskStateTests.java b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskStateTests.java index a35cd6e5fa47..c2dee83a91df 100644 --- a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskStateTests.java +++ b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskStateTests.java @@ -26,11 +26,32 @@ public class ReindexDataStreamPersistentTaskStateTests extends AbstractXContentS @Override protected ReindexDataStreamPersistentTaskState createTestInstance() { - return new ReindexDataStreamPersistentTaskState(randomNegativeLong()); + return new ReindexDataStreamPersistentTaskState( + randomBoolean() ? null : randomNonNegativeInt(), + randomBoolean() ? null : randomNonNegativeInt(), + randomBoolean() ? null : randomNonNegativeLong() + ); } @Override protected ReindexDataStreamPersistentTaskState mutateInstance(ReindexDataStreamPersistentTaskState instance) throws IOException { - return new ReindexDataStreamPersistentTaskState(instance.completionTime() + 1); + return switch (randomInt(2)) { + case 0 -> new ReindexDataStreamPersistentTaskState( + instance.totalIndices() == null ? randomNonNegativeInt() : instance.totalIndices() + 1, + instance.totalIndicesToBeUpgraded(), + instance.completionTime() + ); + case 1 -> new ReindexDataStreamPersistentTaskState( + instance.totalIndices(), + instance.totalIndicesToBeUpgraded() == null ? randomNonNegativeInt() : instance.totalIndicesToBeUpgraded() + 1, + instance.completionTime() + ); + case 2 -> new ReindexDataStreamPersistentTaskState( + instance.totalIndices(), + instance.totalIndicesToBeUpgraded(), + instance.completionTime() == null ? randomNonNegativeLong() : instance.completionTime() + 1 + ); + default -> throw new IllegalArgumentException("Should never get here"); + }; } } diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/DataStreamsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/DataStreamsUpgradeIT.java index 3ce92ea29ec1..4c4f915a8fe1 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/DataStreamsUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/DataStreamsUpgradeIT.java @@ -255,7 +255,11 @@ public class DataStreamsUpgradeIT extends AbstractUpgradeTestCase { } } - private void upgradeDataStream(String dataStreamName, int numRollovers) throws Exception { + private void upgradeDataStream(String dataStreamName, int numRolloversOnOldCluster) throws Exception { + final int explicitRolloverOnNewClusterCount = randomIntBetween(0, 2); + for (int i = 0; i < explicitRolloverOnNewClusterCount; i++) { + rollover(dataStreamName); + } Request reindexRequest = new Request("POST", "/_migration/reindex"); reindexRequest.setJsonEntity(Strings.format(""" { @@ -276,18 +280,27 @@ public class DataStreamsUpgradeIT extends AbstractUpgradeTestCase { ); assertOK(statusResponse); assertThat(statusResponseMap.get("complete"), equalTo(true)); - /* - * total_indices_in_data_stream is determined at the beginning of the reindex, and does not take into account the write - * index being rolled over - */ - assertThat(statusResponseMap.get("total_indices_in_data_stream"), equalTo(numRollovers + 1)); + // The number of rollovers that will have happened when we call reindex: + final int rolloversPerformedByReindex = explicitRolloverOnNewClusterCount == 0 ? 1 : 0; + final int originalWriteIndex = 1; + assertThat( + statusResponseMap.get("total_indices_in_data_stream"), + equalTo(originalWriteIndex + numRolloversOnOldCluster + explicitRolloverOnNewClusterCount + rolloversPerformedByReindex) + ); if (isOriginalClusterSameMajorVersionAsCurrent()) { // If the original cluster was the same as this one, we don't want any indices reindexed: assertThat(statusResponseMap.get("total_indices_requiring_upgrade"), equalTo(0)); assertThat(statusResponseMap.get("successes"), equalTo(0)); } else { - assertThat(statusResponseMap.get("total_indices_requiring_upgrade"), equalTo(numRollovers + 1)); - assertThat(statusResponseMap.get("successes"), equalTo(numRollovers + 1)); + /* + * total_indices_requiring_upgrade is made up of: (the original write index) + numRolloversOnOldCluster. The number of + * rollovers on the upgraded cluster is irrelevant since those will not be reindexed. + */ + assertThat( + statusResponseMap.get("total_indices_requiring_upgrade"), + equalTo(originalWriteIndex + numRolloversOnOldCluster) + ); + assertThat(statusResponseMap.get("successes"), equalTo(numRolloversOnOldCluster + 1)); } }, 60, TimeUnit.SECONDS); Request cancelRequest = new Request("POST", "_migration/reindex/" + dataStreamName + "/_cancel");