[Entitlements] Add support for instrumenting constructors (#117332)

This commit is contained in:
Lorenzo Dematté 2024-11-27 11:31:02 +01:00 committed by GitHub
parent d7737e7306
commit 9799d0082b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 281 additions and 23 deletions

View file

@ -91,15 +91,18 @@ public class InstrumentationServiceImpl implements InstrumentationService {
String.format(
Locale.ROOT,
"Checker method %s has incorrect name format. "
+ "It should be either check$$methodName (instance) or check$package_ClassName$methodName (static)",
+ "It should be either check$$methodName (instance), check$package_ClassName$methodName (static) or "
+ "check$package_ClassName$ (ctor)",
checkerMethodName
)
);
}
// No "className" (check$$methodName) -> method is static, and we'll get the class from the actual typed argument
// No "className" (check$$methodName) -> method is instance, and we'll get the class from the actual typed argument
final boolean targetMethodIsStatic = classNameStartIndex + 1 != classNameEndIndex;
final String targetMethodName = checkerMethodName.substring(classNameEndIndex + 1);
// No "methodName" (check$package_ClassName$) -> method is ctor
final boolean targetMethodIsCtor = classNameEndIndex + 1 == checkerMethodName.length();
final String targetMethodName = targetMethodIsCtor ? "<init>" : checkerMethodName.substring(classNameEndIndex + 1);
final String targetClassName;
final List<String> targetParameterTypes;

View file

@ -154,11 +154,12 @@ public class InstrumenterImpl implements Instrumenter {
var mv = super.visitMethod(access, name, descriptor, signature, exceptions);
if (isAnnotationPresent == false) {
boolean isStatic = (access & ACC_STATIC) != 0;
boolean isCtor = "<init>".equals(name);
var key = new MethodKey(className, name, Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList());
var instrumentationMethod = instrumentationMethods.get(key);
if (instrumentationMethod != null) {
// LOGGER.debug("Will instrument method {}", key);
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod);
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, isCtor, descriptor, instrumentationMethod);
} else {
// LOGGER.trace("Will not instrument method {}", key);
}
@ -187,6 +188,7 @@ public class InstrumenterImpl implements Instrumenter {
class EntitlementMethodVisitor extends MethodVisitor {
private final boolean instrumentedMethodIsStatic;
private final boolean instrumentedMethodIsCtor;
private final String instrumentedMethodDescriptor;
private final CheckerMethod instrumentationMethod;
private boolean hasCallerSensitiveAnnotation = false;
@ -195,11 +197,13 @@ public class InstrumenterImpl implements Instrumenter {
int api,
MethodVisitor methodVisitor,
boolean instrumentedMethodIsStatic,
boolean instrumentedMethodIsCtor,
String instrumentedMethodDescriptor,
CheckerMethod instrumentationMethod
) {
super(api, methodVisitor);
this.instrumentedMethodIsStatic = instrumentedMethodIsStatic;
this.instrumentedMethodIsCtor = instrumentedMethodIsCtor;
this.instrumentedMethodDescriptor = instrumentedMethodDescriptor;
this.instrumentationMethod = instrumentationMethod;
}
@ -260,14 +264,15 @@ public class InstrumenterImpl implements Instrumenter {
private void forwardIncomingArguments() {
int localVarIndex = 0;
if (instrumentedMethodIsStatic == false) {
if (instrumentedMethodIsCtor) {
localVarIndex++;
} else if (instrumentedMethodIsStatic == false) {
mv.visitVarInsn(Opcodes.ALOAD, localVarIndex++);
}
for (Type type : Type.getArgumentTypes(instrumentedMethodDescriptor)) {
mv.visitVarInsn(type.getOpcode(Opcodes.ILOAD), localVarIndex);
localVarIndex += type.getSize();
}
}
private void invokeInstrumentationMethod() {

View file

@ -45,6 +45,12 @@ public class InstrumentationServiceImplTests extends ESTestCase {
void check$org_example_TestTargetClass$staticMethodWithOverload(Class<?> clazz, int x, String y);
}
interface TestCheckerCtors {
void check$org_example_TestTargetClass$(Class<?> clazz);
void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
}
public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
@ -142,6 +148,38 @@ public class InstrumentationServiceImplTests extends ESTestCase {
);
}
public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
assertThat(methodsMap, aMapWithSize(2));
assertThat(
methodsMap,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of("I", "java/lang/String"))),
equalTo(
new CheckerMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
"check$org_example_TestTargetClass$",
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;")
)
)
)
);
assertThat(
methodsMap,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of())),
equalTo(
new CheckerMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
"check$org_example_TestTargetClass$",
List.of("Ljava/lang/Class;")
)
)
)
);
}
public void testParseCheckerMethodSignatureStaticMethod() {
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
"check$org_example_TestClass$staticMethod",
@ -169,6 +207,24 @@ public class InstrumentationServiceImplTests extends ESTestCase {
assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass$InnerClass", "staticMethod", List.of())));
}
public void testParseCheckerMethodSignatureCtor() {
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
"check$org_example_TestClass$",
new Type[] { Type.getType(Class.class) }
);
assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass", "<init>", List.of())));
}
public void testParseCheckerMethodSignatureCtorWithArgs() {
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
"check$org_example_TestClass$",
new Type[] { Type.getType(Class.class), Type.getType("I"), Type.getType(String.class) }
);
assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass", "<init>", List.of("I", "java/lang/String"))));
}
public void testParseCheckerMethodSignatureIncorrectName() {
var exception = assertThrows(
IllegalArgumentException.class,

View file

@ -23,12 +23,15 @@ import org.objectweb.asm.Type;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLStreamHandlerFactory;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.objectweb.asm.Opcodes.INVOKESTATIC;
@ -72,6 +75,11 @@ public class InstrumenterTests extends ESTestCase {
* They must not throw {@link TestException}.
*/
public static class ClassToInstrument implements Testable {
public ClassToInstrument() {}
public ClassToInstrument(int arg) {}
public static void systemExit(int status) {
assertEquals(123, status);
}
@ -91,12 +99,20 @@ public class InstrumenterTests extends ESTestCase {
static final class TestException extends RuntimeException {}
/**
* Interface to test specific, "synthetic" cases (e.g. overloaded methods, overloaded constructors, etc.) that
* may be not present/may be difficult to find or not clear in the production EntitlementChecker interface
*/
public interface MockEntitlementChecker extends EntitlementChecker {
void checkSomeStaticMethod(Class<?> clazz, int arg);
void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
void checkCtor(Class<?> clazz);
void checkCtor(Class<?> clazz, int arg);
}
/**
@ -118,6 +134,9 @@ public class InstrumenterTests extends ESTestCase {
int checkSomeStaticMethodIntStringCallCount = 0;
int checkSomeInstanceMethodCallCount = 0;
int checkCtorCallCount = 0;
int checkCtorIntCallCount = 0;
@Override
public void check$java_lang_System$exit(Class<?> callerClass, int status) {
checkSystemExitCallCount++;
@ -126,6 +145,27 @@ public class InstrumenterTests extends ESTestCase {
throwIfActive();
}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls) {}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent) {}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory) {}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent) {}
@Override
public void check$java_net_URLClassLoader$(
Class<?> callerClass,
String name,
URL[] urls,
ClassLoader parent,
URLStreamHandlerFactory factory
) {}
private void throwIfActive() {
if (isActive) {
throw new TestException();
@ -161,6 +201,21 @@ public class InstrumenterTests extends ESTestCase {
assertEquals("def", anotherArg);
throwIfActive();
}
@Override
public void checkCtor(Class<?> callerClass) {
checkCtorCallCount++;
assertSame(InstrumenterTests.class, callerClass);
throwIfActive();
}
@Override
public void checkCtor(Class<?> callerClass, int arg) {
checkCtorIntCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
throwIfActive();
}
}
public void testClassIsInstrumented() throws Exception {
@ -225,7 +280,7 @@ public class InstrumenterTests extends ESTestCase {
getTestEntitlementChecker().checkSystemExitCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(1));
assertEquals(1, getTestEntitlementChecker().checkSystemExitCallCount);
}
public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
@ -259,10 +314,10 @@ public class InstrumenterTests extends ESTestCase {
getTestEntitlementChecker().checkSystemExitCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(1));
assertEquals(1, getTestEntitlementChecker().checkSystemExitCallCount);
assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherSystemExit", 123));
assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(2));
assertEquals(2, getTestEntitlementChecker().checkSystemExitCallCount);
}
public void testInstrumenterWorksWithOverloads() throws Exception {
@ -294,8 +349,8 @@ public class InstrumenterTests extends ESTestCase {
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
assertThat(getTestEntitlementChecker().checkSomeStaticMethodIntCallCount, is(1));
assertThat(getTestEntitlementChecker().checkSomeStaticMethodIntStringCallCount, is(1));
assertEquals(1, getTestEntitlementChecker().checkSomeStaticMethodIntCallCount);
assertEquals(1, getTestEntitlementChecker().checkSomeStaticMethodIntStringCallCount);
}
public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
@ -327,7 +382,41 @@ public class InstrumenterTests extends ESTestCase {
testTargetClass.someMethod(123);
assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
assertThat(getTestEntitlementChecker().checkSomeInstanceMethodCallCount, is(1));
assertEquals(1, getTestEntitlementChecker().checkSomeInstanceMethodCallCount);
}
public void testInstrumenterWorksWithConstructors() throws Exception {
var classToInstrument = ClassToInstrument.class;
Map<MethodKey, CheckerMethod> methods = Map.of(
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
getCheckerMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
getCheckerMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
);
var instrumenter = createInstrumenter(methods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
getTestEntitlementChecker().isActive = true;
var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
assertThat(ex.getCause(), instanceOf(TestException.class));
var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
assertThat(ex2.getCause(), instanceOf(TestException.class));
assertEquals(1, getTestEntitlementChecker().checkCtorCallCount);
assertEquals(1, getTestEntitlementChecker().checkCtorIntCallCount);
}
/** This test doesn't replace classToInstrument in-place but instead loads a separate

View file

@ -9,6 +9,20 @@
package org.elasticsearch.entitlement.bridge;
import java.net.URL;
import java.net.URLStreamHandlerFactory;
public interface EntitlementChecker {
void check$java_lang_System$exit(Class<?> callerClass, int status);
// URLClassLoader ctor
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls);
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent);
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory);
void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent);
void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory);
}

View file

@ -169,10 +169,6 @@ public class EntitlementInitialization {
}
}
private static String internalName(Class<?> c) {
return c.getName().replace('.', '/');
}
private static final InstrumentationService INSTRUMENTER_FACTORY = new ProviderLocator<>(
"entitlement",
InstrumentationService.class,

View file

@ -13,6 +13,9 @@ import org.elasticsearch.entitlement.bridge.EntitlementChecker;
import org.elasticsearch.entitlement.runtime.policy.FlagEntitlementType;
import org.elasticsearch.entitlement.runtime.policy.PolicyManager;
import java.net.URL;
import java.net.URLStreamHandlerFactory;
/**
* Implementation of the {@link EntitlementChecker} interface, providing additional
* API methods for managing the checks.
@ -29,4 +32,35 @@ public class ElasticsearchEntitlementChecker implements EntitlementChecker {
public void check$java_lang_System$exit(Class<?> callerClass, int status) {
policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.SYSTEM_EXIT);
}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls) {
policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER);
}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent) {
policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER);
}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory) {
policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER);
}
@Override
public void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent) {
policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER);
}
@Override
public void check$java_net_URLClassLoader$(
Class<?> callerClass,
String name,
URL[] urls,
ClassLoader parent,
URLStreamHandlerFactory factory
) {
policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER);
}
}

View file

@ -10,5 +10,6 @@
package org.elasticsearch.entitlement.runtime.policy;
public enum FlagEntitlementType {
SYSTEM_EXIT;
SYSTEM_EXIT,
CREATE_CLASSLOADER;
}

View file

@ -66,7 +66,7 @@ public class PolicyManager {
// TODO: this will be checked using policies
if (requestingModule.isNamed()
&& requestingModule.getName().equals("org.elasticsearch.server")
&& type == FlagEntitlementType.SYSTEM_EXIT) {
&& (type == FlagEntitlementType.SYSTEM_EXIT || type == FlagEntitlementType.CREATE_CLASSLOADER)) {
logger.debug("Allowed: caller [{}] in module [{}] has entitlement [{}]", callerClass, requestingModule.getName(), type);
return;
}

View file

@ -39,4 +39,11 @@ public class EntitlementsIT extends ESRestTestCase {
);
assertThat(exception.getMessage(), containsString("not_entitled_exception"));
}
public void testCheckCreateURLClassLoader() {
var exception = expectThrows(IOException.class, () -> {
client().performRequest(new Request("GET", "/_entitlement/_check_create_url_classloader"));
});
assertThat(exception.getMessage(), containsString("not_entitled_exception"));
}
}

View file

@ -22,7 +22,6 @@ import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;
@ -42,6 +41,6 @@ public class EntitlementsCheckPlugin extends Plugin implements ActionPlugin {
final Supplier<DiscoveryNodes> nodesInCluster,
Predicate<NodeFeature> clusterSupportsFeature
) {
return Collections.singletonList(new RestEntitlementsCheckSystemExitAction());
return List.of(new RestEntitlementsCheckSystemExitAction(), new RestEntitlementsCheckClassLoaderAction());
}
}

View file

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.test.entitlements;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.GET;
public class RestEntitlementsCheckClassLoaderAction extends BaseRestHandler {
private static final Logger logger = LogManager.getLogger(RestEntitlementsCheckClassLoaderAction.class);
RestEntitlementsCheckClassLoaderAction() {}
@Override
public List<Route> routes() {
return List.of(new Route(GET, "/_entitlement/_check_create_url_classloader"));
}
@Override
public String getName() {
return "check_classloader_action";
}
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
logger.info("RestEntitlementsCheckClassLoaderAction rest handler [{}]", request.path());
if (request.path().equals("/_entitlement/_check_create_url_classloader")) {
return channel -> {
logger.info("Calling new URLClassLoader");
try (var classLoader = new URLClassLoader("test", new URL[0], this.getClass().getClassLoader())) {
logger.info("Created URLClassLoader [{}]", classLoader.getName());
}
};
}
throw new UnsupportedOperationException();
}
}

View file

@ -210,7 +210,7 @@ class Elasticsearch {
bootstrap.setPluginsLoader(pluginsLoader);
if (Boolean.parseBoolean(System.getProperty("es.entitlements.enabled"))) {
logger.info("Bootstrapping Entitlements");
LogManager.getLogger(Elasticsearch.class).info("Bootstrapping Entitlements");
List<Tuple<Path, Boolean>> pluginData = new ArrayList<>();
Set<PluginBundle> moduleBundles = PluginsUtils.getModuleBundles(nodeEnv.modulesFile());
@ -225,7 +225,7 @@ class Elasticsearch {
EntitlementBootstrap.bootstrap(pluginData, callerClass -> null);
} else {
// install SM after natives, shutdown hooks, etc.
logger.info("Bootstrapping java SecurityManager");
LogManager.getLogger(Elasticsearch.class).info("Bootstrapping java SecurityManager");
org.elasticsearch.bootstrap.Security.configure(
nodeEnv,
SECURITY_FILTER_BAD_DEFAULTS_SETTING.get(args.nodeSettings()),