Simplify instrumenter and tests (#118493)

This commit simplifies the entitlements instrumentation service and
instrumenter a bit. It especially removes some repetition in the
instrumenter tests.
This commit is contained in:
Ryan Ernst 2024-12-13 15:23:37 -08:00 committed by GitHub
parent f900ae61bb
commit b456e16c7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 222 additions and 327 deletions

View file

@ -29,15 +29,13 @@ import java.util.Map;
public class InstrumentationServiceImpl implements InstrumentationService {
@Override
public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
return InstrumenterImpl.create(checkMethods);
public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods) {
return InstrumenterImpl.create(clazz, methods);
}
@Override
public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
IOException {
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
var checkerClass = Class.forName(entitlementCheckerClassName);
var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {

View file

@ -58,30 +58,14 @@ public class InstrumenterImpl implements Instrumenter {
this.checkMethods = checkMethods;
}
static String getCheckerClassName() {
int javaVersion = Runtime.version().feature();
final String classNamePrefix;
if (javaVersion >= 23) {
classNamePrefix = "Java23";
} else {
classNamePrefix = "";
}
return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
}
public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
String checkerClass = getCheckerClassName();
String handleClass = checkerClass + "Handle";
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
public static InstrumenterImpl create(Class<?> checkerClass, Map<MethodKey, CheckMethod> checkMethods) {
Type checkerClassType = Type.getType(checkerClass);
String handleClass = checkerClassType.getInternalName() + "Handle";
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(checkerClassType);
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
}
public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
ClassFileInfo initial = getClassFileInfo(clazz);
return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes()));
}
public static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
String internalName = Type.getInternalName(clazz);
String fileName = "/" + internalName + ".class";
byte[] originalBytecodes;
@ -306,5 +290,5 @@ public class InstrumenterImpl implements Instrumenter {
mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
}
public record ClassFileInfo(String fileName, byte[] bytecodes) {}
record ClassFileInfo(String fileName, byte[] bytecodes) {}
}

View file

@ -51,8 +51,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
}
public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
public void testInstrumentationTargetLookup() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
assertThat(checkMethods, aMapWithSize(3));
assertThat(
@ -116,8 +116,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
);
}
public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
public void testInstrumentationTargetLookupWithOverloads() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);
assertThat(checkMethods, aMapWithSize(2));
assertThat(
@ -148,8 +148,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
);
}
public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
public void testInstrumentationTargetLookupWithCtors() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);
assertThat(checkMethods, aMapWithSize(2));
assertThat(

View file

@ -12,31 +12,64 @@ package org.elasticsearch.entitlement.instrumentation.impl;
import org.elasticsearch.common.Strings;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.objectweb.asm.Type;
import java.io.IOException;
import java.lang.reflect.AccessFlag;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.hamcrest.Matchers.equalTo;
/**
* This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
* some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
* This tests {@link InstrumenterImpl} can instrument various method signatures
* (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
*/
@ESTestCase.WithoutSecurityManager
public class InstrumenterTests extends ESTestCase {
private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
static class TestLoader extends ClassLoader {
final byte[] testClassBytes;
final Class<?> testClass;
TestLoader(String testClassName, byte[] testClassBytes) {
super(InstrumenterTests.class.getClassLoader());
this.testClassBytes = testClassBytes;
this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length);
}
Method getSameMethod(Method method) {
try {
return testClass.getMethod(method.getName(), method.getParameterTypes());
} catch (NoSuchMethodException e) {
throw new AssertionError(e);
}
}
Constructor<?> getSameConstructor(Constructor<?> ctor) {
try {
return testClass.getConstructor(ctor.getParameterTypes());
} catch (NoSuchMethodException e) {
throw new AssertionError(e);
}
}
}
/**
* Contains all the virtual methods from {@link TestClassToInstrument},
* allowing this test to call them on the dynamically loaded instrumented class.
@ -80,13 +113,15 @@ public class InstrumenterTests extends ESTestCase {
public interface MockEntitlementChecker {
void checkSomeStaticMethod(Class<?> clazz, int arg);
void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
void checkSomeStaticMethodOverload(Class<?> clazz, int arg, String anotherArg);
void checkAnotherStaticMethod(Class<?> clazz, int arg);
void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
void checkCtor(Class<?> clazz);
void checkCtor(Class<?> clazz, int arg);
void checkCtorOverload(Class<?> clazz, int arg);
}
public static class TestEntitlementCheckerHolder {
@ -105,6 +140,7 @@ public class InstrumenterTests extends ESTestCase {
volatile boolean isActive;
int checkSomeStaticMethodIntCallCount = 0;
int checkAnotherStaticMethodIntCallCount = 0;
int checkSomeStaticMethodIntStringCallCount = 0;
int checkSomeInstanceMethodCallCount = 0;
@ -120,28 +156,33 @@ public class InstrumenterTests extends ESTestCase {
@Override
public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
checkSomeStaticMethodIntCallCount++;
assertSame(TestMethodUtils.class, callerClass);
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
throwIfActive();
}
@Override
public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
public void checkSomeStaticMethodOverload(Class<?> callerClass, int arg, String anotherArg) {
checkSomeStaticMethodIntStringCallCount++;
assertSame(TestMethodUtils.class, callerClass);
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
assertEquals("abc", anotherArg);
throwIfActive();
}
@Override
public void checkAnotherStaticMethod(Class<?> callerClass, int arg) {
checkAnotherStaticMethodIntCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
throwIfActive();
}
@Override
public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
checkSomeInstanceMethodCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertThat(
that.getClass().getName(),
startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument")
);
assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName()));
assertEquals(123, arg);
assertEquals("def", anotherArg);
throwIfActive();
@ -155,7 +196,7 @@ public class InstrumenterTests extends ESTestCase {
}
@Override
public void checkCtor(Class<?> callerClass, int arg) {
public void checkCtorOverload(Class<?> callerClass, int arg) {
checkCtorIntCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
@ -163,206 +204,83 @@ public class InstrumenterTests extends ESTestCase {
}
}
public void testClassIsInstrumented() throws Exception {
var classToInstrument = TestClassToInstrument.class;
@Before
public void resetInstance() {
TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker();
}
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
checkMethod
);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = false;
public void testStaticMethod() throws Exception {
Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)));
// Before checking is active, nothing should throw
callStaticMethod(newClass, "someStaticMethod", 123);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
assertStaticMethod(loader, targetMethod, 123);
// After checking is activated, everything should throw
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertStaticMethodThrows(loader, targetMethod, 123);
}
public void testClassIsNotInstrumentedTwice() throws Exception {
var classToInstrument = TestClassToInstrument.class;
public void testNotInstrumentedTwice() throws Exception {
Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod));
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
checkMethod
);
var instrumenter = createInstrumenter(checkMethods);
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
var internalClassName = Type.getInternalName(classToInstrument);
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
var loader1 = instrumentTestClass(instrumenter);
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes);
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode);
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW_NEW",
instrumentedTwiceBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertStaticMethodThrows(loader2, targetMethod, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
}
public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
var classToInstrument = TestClassToInstrument.class;
public void testMultipleMethods() throws Exception {
Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class);
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
checkMethod,
methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
checkMethod
);
var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2));
var loader = instrumentTestClass(instrumenter);
var instrumenter = createInstrumenter(checkMethods);
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
var internalClassName = Type.getInternalName(classToInstrument);
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW_NEW",
instrumentedTwiceBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertStaticMethodThrows(loader, targetMethod1, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
assertStaticMethodThrows(loader, targetMethod2, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount);
}
public void testInstrumenterWorksWithOverloads() throws Exception {
var classToInstrument = TestClassToInstrument.class;
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
public void testStaticMethodOverload() throws Exception {
Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class);
var instrumenter = createInstrumenter(
Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2)
);
var loader = instrumentTestClass(instrumenter);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
// After checking is activated, everything should throw
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
assertStaticMethodThrows(loader, targetMethod1, 123);
assertStaticMethodThrows(loader, targetMethod2, 123, "abc");
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
}
public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
var classToInstrument = TestClassToInstrument.class;
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
public void testInstanceMethodOverload() throws Exception {
Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class);
var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod));
var loader = instrumentTestClass(instrumenter);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance());
// This overload is not instrumented, so it will not throw
testTargetClass.someMethod(123);
assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
}
public void testInstrumenterWorksWithConstructors() throws Exception {
var classToInstrument = TestClassToInstrument.class;
Map<MethodKey, CheckMethod> checkMethods = Map.of(
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
assertThat(ex.getCause(), instanceOf(TestException.class));
var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
assertThat(ex2.getCause(), instanceOf(TestException.class));
public void testConstructors() throws Exception {
Constructor<?> ctor1 = TestClassToInstrument.class.getConstructor();
Constructor<?> ctor2 = TestClassToInstrument.class.getConstructor(int.class);
var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2)));
assertCtorThrows(loader, ctor1);
assertCtorThrows(loader, ctor2, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
}
@ -373,11 +291,107 @@ public class InstrumenterTests extends ESTestCase {
* MethodKey and instrumentationMethod with slightly different signatures (using the common interface
* Testable) which is not what would happen when it's run by the agent.
*/
private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
private static InstrumenterImpl createInstrumenter(Map<String, Executable> methods) throws NoSuchMethodException {
Map<MethodKey, CheckMethod> checkMethods = new HashMap<>();
for (var entry : methods.entrySet()) {
checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue()));
}
String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class);
String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class);
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
}
private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException {
var clazz = TestClassToInstrument.class;
ClassFileInfo initial = getClassFileInfo(clazz);
byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes());
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
return new TestLoader(clazz.getName(), newBytecode);
}
private static MethodKey getMethodKey(Executable method) {
logger.info("method key: {}", method.getName());
String methodName = method instanceof Constructor<?> ? "<init>" : method.getName();
return new MethodKey(
Type.getInternalName(method.getDeclaringClass()),
methodName,
Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList()
);
}
private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException {
Set<AccessFlag> flags = targetMethod.accessFlags();
boolean isInstance = flags.contains(AccessFlag.STATIC) == false && targetMethod instanceof Method;
int extraArgs = 1; // caller class
if (isInstance) {
++extraArgs;
}
Class<?>[] targetParameterTypes = targetMethod.getParameterTypes();
Class<?>[] checkParameterTypes = new Class<?>[targetParameterTypes.length + extraArgs];
checkParameterTypes[0] = Class.class;
if (isInstance) {
checkParameterTypes[1] = Testable.class;
}
System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length);
var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes);
return new CheckMethod(
Type.getInternalName(MockEntitlementChecker.class),
checkMethod.getName(),
Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList()
);
}
private static void unwrapInvocationException(InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {
// Sometimes we're expecting this one!
throw n;
} else {
throw new AssertionError(cause);
}
}
/**
* Calling a static method of a dynamically loaded class is significantly more cumbersome
* than calling a virtual method.
*/
static void callStaticMethod(Method method, Object... args) {
try {
method.invoke(null, args);
} catch (InvocationTargetException e) {
unwrapInvocationException(e);
} catch (IllegalAccessException e) {
throw new AssertionError(e);
}
}
private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) {
Method testMethod = loader.getSameMethod(method);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
expectThrows(TestException.class, () -> callStaticMethod(testMethod, args));
}
private void assertStaticMethod(TestLoader loader, Method method, Object... args) {
Method testMethod = loader.getSameMethod(method);
TestEntitlementCheckerHolder.checkerInstance.isActive = false;
callStaticMethod(testMethod, args);
}
private void assertCtorThrows(TestLoader loader, Constructor<?> ctor, Object... args) {
Constructor<?> testCtor = loader.getSameConstructor(ctor);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
expectThrows(TestException.class, () -> {
try {
testCtor.newInstance(args);
} catch (InvocationTargetException e) {
unwrapInvocationException(e);
} catch (IllegalAccessException | InstantiationException e) {
throw new AssertionError(e);
}
});
}
}

View file

@ -1,20 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.instrumentation.impl;
class TestLoader extends ClassLoader {
TestLoader(ClassLoader parent) {
super(parent);
}
public Class<?> defineClassFromBytes(String name, byte[] bytes) {
return defineClass(name, bytes, 0, bytes.length);
}
}

View file

@ -1,81 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.instrumentation.impl;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.objectweb.asm.Type;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
class TestMethodUtils {
/**
* @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
*/
static MethodKey methodKeyForTarget(Method targetMethod) {
Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
return new MethodKey(
Type.getInternalName(targetMethod.getDeclaringClass()),
targetMethod.getName(),
Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
);
}
static MethodKey methodKeyForConstructor(Class<?> classToInstrument, List<String> params) {
return new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", params);
}
static CheckMethod getCheckMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) throws NoSuchMethodException {
var method = clazz.getMethod(methodName, parameterTypes);
return new CheckMethod(
Type.getInternalName(clazz),
method.getName(),
Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
);
}
/**
* Calling a static method of a dynamically loaded class is significantly more cumbersome
* than calling a virtual method.
*/
static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
try {
c.getMethod(methodName, int.class).invoke(null, arg);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {
// Sometimes we're expecting this one!
throw n;
} else {
throw new AssertionError(cause);
}
}
}
static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
IllegalAccessException {
try {
c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {
// Sometimes we're expecting this one!
throw n;
} else {
throw new AssertionError(cause);
}
}
}
}

View file

@ -14,6 +14,7 @@ import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap;
import org.elasticsearch.entitlement.bridge.EntitlementChecker;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
import org.elasticsearch.entitlement.instrumentation.Instrumenter;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.elasticsearch.entitlement.instrumentation.Transformer;
import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker;
@ -64,13 +65,12 @@ public class EntitlementInitialization {
public static void initialize(Instrumentation inst) throws Exception {
manager = initChecker();
Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
"org.elasticsearch.entitlement.bridge.EntitlementChecker"
);
Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethods(EntitlementChecker.class);
var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true);
Instrumenter instrumenter = INSTRUMENTER_FACTORY.newInstrumenter(EntitlementChecker.class, checkMethods);
inst.addTransformer(new Transformer(instrumenter, classesToTransform), true);
// TODO: should we limit this array somehow?
var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new);
inst.retransformClasses(classesToRetransform);

View file

@ -16,7 +16,7 @@ import java.util.Map;
* The SPI service entry point for instrumentation.
*/
public interface InstrumentationService {
Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods);
Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods);
Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException;
Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws IOException;
}