mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 15:17:30 -04:00
Simplify instrumenter and tests (#118493)
This commit simplifies the entitlements instrumentation service and instrumenter a bit. It especially removes some repetition in the instrumenter tests.
This commit is contained in:
parent
f900ae61bb
commit
b456e16c7d
8 changed files with 222 additions and 327 deletions
|
@ -29,15 +29,13 @@ import java.util.Map;
|
|||
public class InstrumentationServiceImpl implements InstrumentationService {
|
||||
|
||||
@Override
|
||||
public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
|
||||
return InstrumenterImpl.create(checkMethods);
|
||||
public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods) {
|
||||
return InstrumenterImpl.create(clazz, methods);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
|
||||
IOException {
|
||||
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
|
||||
var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
|
||||
var checkerClass = Class.forName(entitlementCheckerClassName);
|
||||
var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
|
||||
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
|
||||
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
|
||||
|
|
|
@ -58,30 +58,14 @@ public class InstrumenterImpl implements Instrumenter {
|
|||
this.checkMethods = checkMethods;
|
||||
}
|
||||
|
||||
static String getCheckerClassName() {
|
||||
int javaVersion = Runtime.version().feature();
|
||||
final String classNamePrefix;
|
||||
if (javaVersion >= 23) {
|
||||
classNamePrefix = "Java23";
|
||||
} else {
|
||||
classNamePrefix = "";
|
||||
}
|
||||
return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
|
||||
}
|
||||
|
||||
public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
|
||||
String checkerClass = getCheckerClassName();
|
||||
String handleClass = checkerClass + "Handle";
|
||||
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
|
||||
public static InstrumenterImpl create(Class<?> checkerClass, Map<MethodKey, CheckMethod> checkMethods) {
|
||||
Type checkerClassType = Type.getType(checkerClass);
|
||||
String handleClass = checkerClassType.getInternalName() + "Handle";
|
||||
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(checkerClassType);
|
||||
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
|
||||
}
|
||||
|
||||
public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
|
||||
ClassFileInfo initial = getClassFileInfo(clazz);
|
||||
return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes()));
|
||||
}
|
||||
|
||||
public static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
|
||||
static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
|
||||
String internalName = Type.getInternalName(clazz);
|
||||
String fileName = "/" + internalName + ".class";
|
||||
byte[] originalBytecodes;
|
||||
|
@ -306,5 +290,5 @@ public class InstrumenterImpl implements Instrumenter {
|
|||
mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
|
||||
}
|
||||
|
||||
public record ClassFileInfo(String fileName, byte[] bytecodes) {}
|
||||
record ClassFileInfo(String fileName, byte[] bytecodes) {}
|
||||
}
|
||||
|
|
|
@ -51,8 +51,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
|
|||
void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
|
||||
}
|
||||
|
||||
public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
|
||||
public void testInstrumentationTargetLookup() throws IOException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
|
||||
|
||||
assertThat(checkMethods, aMapWithSize(3));
|
||||
assertThat(
|
||||
|
@ -116,8 +116,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
|
||||
public void testInstrumentationTargetLookupWithOverloads() throws IOException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);
|
||||
|
||||
assertThat(checkMethods, aMapWithSize(2));
|
||||
assertThat(
|
||||
|
@ -148,8 +148,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
|
||||
public void testInstrumentationTargetLookupWithCtors() throws IOException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);
|
||||
|
||||
assertThat(checkMethods, aMapWithSize(2));
|
||||
assertThat(
|
||||
|
|
|
@ -12,31 +12,64 @@ package org.elasticsearch.entitlement.instrumentation.impl;
|
|||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
|
||||
import org.elasticsearch.entitlement.instrumentation.MethodKey;
|
||||
import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo;
|
||||
import org.elasticsearch.logging.LogManager;
|
||||
import org.elasticsearch.logging.Logger;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.Before;
|
||||
import org.objectweb.asm.Type;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.AccessFlag;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.List;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
|
||||
import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
|
||||
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
|
||||
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
|
||||
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
/**
|
||||
* This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
|
||||
* some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
|
||||
* This tests {@link InstrumenterImpl} can instrument various method signatures
|
||||
* (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
|
||||
*/
|
||||
@ESTestCase.WithoutSecurityManager
|
||||
public class InstrumenterTests extends ESTestCase {
|
||||
private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
|
||||
|
||||
static class TestLoader extends ClassLoader {
|
||||
final byte[] testClassBytes;
|
||||
final Class<?> testClass;
|
||||
|
||||
TestLoader(String testClassName, byte[] testClassBytes) {
|
||||
super(InstrumenterTests.class.getClassLoader());
|
||||
this.testClassBytes = testClassBytes;
|
||||
this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length);
|
||||
}
|
||||
|
||||
Method getSameMethod(Method method) {
|
||||
try {
|
||||
return testClass.getMethod(method.getName(), method.getParameterTypes());
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
Constructor<?> getSameConstructor(Constructor<?> ctor) {
|
||||
try {
|
||||
return testClass.getConstructor(ctor.getParameterTypes());
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Contains all the virtual methods from {@link TestClassToInstrument},
|
||||
* allowing this test to call them on the dynamically loaded instrumented class.
|
||||
|
@ -80,13 +113,15 @@ public class InstrumenterTests extends ESTestCase {
|
|||
public interface MockEntitlementChecker {
|
||||
void checkSomeStaticMethod(Class<?> clazz, int arg);
|
||||
|
||||
void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
|
||||
void checkSomeStaticMethodOverload(Class<?> clazz, int arg, String anotherArg);
|
||||
|
||||
void checkAnotherStaticMethod(Class<?> clazz, int arg);
|
||||
|
||||
void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
|
||||
|
||||
void checkCtor(Class<?> clazz);
|
||||
|
||||
void checkCtor(Class<?> clazz, int arg);
|
||||
void checkCtorOverload(Class<?> clazz, int arg);
|
||||
}
|
||||
|
||||
public static class TestEntitlementCheckerHolder {
|
||||
|
@ -105,6 +140,7 @@ public class InstrumenterTests extends ESTestCase {
|
|||
volatile boolean isActive;
|
||||
|
||||
int checkSomeStaticMethodIntCallCount = 0;
|
||||
int checkAnotherStaticMethodIntCallCount = 0;
|
||||
int checkSomeStaticMethodIntStringCallCount = 0;
|
||||
int checkSomeInstanceMethodCallCount = 0;
|
||||
|
||||
|
@ -120,28 +156,33 @@ public class InstrumenterTests extends ESTestCase {
|
|||
@Override
|
||||
public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
|
||||
checkSomeStaticMethodIntCallCount++;
|
||||
assertSame(TestMethodUtils.class, callerClass);
|
||||
assertSame(InstrumenterTests.class, callerClass);
|
||||
assertEquals(123, arg);
|
||||
throwIfActive();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
|
||||
public void checkSomeStaticMethodOverload(Class<?> callerClass, int arg, String anotherArg) {
|
||||
checkSomeStaticMethodIntStringCallCount++;
|
||||
assertSame(TestMethodUtils.class, callerClass);
|
||||
assertSame(InstrumenterTests.class, callerClass);
|
||||
assertEquals(123, arg);
|
||||
assertEquals("abc", anotherArg);
|
||||
throwIfActive();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkAnotherStaticMethod(Class<?> callerClass, int arg) {
|
||||
checkAnotherStaticMethodIntCallCount++;
|
||||
assertSame(InstrumenterTests.class, callerClass);
|
||||
assertEquals(123, arg);
|
||||
throwIfActive();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
|
||||
checkSomeInstanceMethodCallCount++;
|
||||
assertSame(InstrumenterTests.class, callerClass);
|
||||
assertThat(
|
||||
that.getClass().getName(),
|
||||
startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument")
|
||||
);
|
||||
assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName()));
|
||||
assertEquals(123, arg);
|
||||
assertEquals("def", anotherArg);
|
||||
throwIfActive();
|
||||
|
@ -155,7 +196,7 @@ public class InstrumenterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void checkCtor(Class<?> callerClass, int arg) {
|
||||
public void checkCtorOverload(Class<?> callerClass, int arg) {
|
||||
checkCtorIntCallCount++;
|
||||
assertSame(InstrumenterTests.class, callerClass);
|
||||
assertEquals(123, arg);
|
||||
|
@ -163,206 +204,83 @@ public class InstrumenterTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testClassIsInstrumented() throws Exception {
|
||||
var classToInstrument = TestClassToInstrument.class;
|
||||
@Before
|
||||
public void resetInstance() {
|
||||
TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker();
|
||||
}
|
||||
|
||||
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
|
||||
Map<MethodKey, CheckMethod> checkMethods = Map.of(
|
||||
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
|
||||
checkMethod
|
||||
);
|
||||
|
||||
var instrumenter = createInstrumenter(checkMethods);
|
||||
|
||||
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
|
||||
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
|
||||
}
|
||||
|
||||
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
|
||||
classToInstrument.getName() + "_NEW",
|
||||
newBytecode
|
||||
);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = false;
|
||||
public void testStaticMethod() throws Exception {
|
||||
Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
|
||||
TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)));
|
||||
|
||||
// Before checking is active, nothing should throw
|
||||
callStaticMethod(newClass, "someStaticMethod", 123);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
|
||||
assertStaticMethod(loader, targetMethod, 123);
|
||||
// After checking is activated, everything should throw
|
||||
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
|
||||
assertStaticMethodThrows(loader, targetMethod, 123);
|
||||
}
|
||||
|
||||
public void testClassIsNotInstrumentedTwice() throws Exception {
|
||||
var classToInstrument = TestClassToInstrument.class;
|
||||
public void testNotInstrumentedTwice() throws Exception {
|
||||
Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
|
||||
var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod));
|
||||
|
||||
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
|
||||
Map<MethodKey, CheckMethod> checkMethods = Map.of(
|
||||
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
|
||||
checkMethod
|
||||
);
|
||||
|
||||
var instrumenter = createInstrumenter(checkMethods);
|
||||
|
||||
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
|
||||
var internalClassName = Type.getInternalName(classToInstrument);
|
||||
|
||||
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
|
||||
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
|
||||
|
||||
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
|
||||
var loader1 = instrumentTestClass(instrumenter);
|
||||
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes);
|
||||
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
|
||||
var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode);
|
||||
|
||||
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
|
||||
classToInstrument.getName() + "_NEW_NEW",
|
||||
instrumentedTwiceBytecode
|
||||
);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
|
||||
|
||||
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
|
||||
assertStaticMethodThrows(loader2, targetMethod, 123);
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
|
||||
}
|
||||
|
||||
public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
|
||||
var classToInstrument = TestClassToInstrument.class;
|
||||
public void testMultipleMethods() throws Exception {
|
||||
Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
|
||||
Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class);
|
||||
|
||||
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
|
||||
Map<MethodKey, CheckMethod> checkMethods = Map.of(
|
||||
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
|
||||
checkMethod,
|
||||
methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
|
||||
checkMethod
|
||||
);
|
||||
var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2));
|
||||
var loader = instrumentTestClass(instrumenter);
|
||||
|
||||
var instrumenter = createInstrumenter(checkMethods);
|
||||
|
||||
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
|
||||
var internalClassName = Type.getInternalName(classToInstrument);
|
||||
|
||||
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
|
||||
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
|
||||
|
||||
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
|
||||
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
|
||||
|
||||
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
|
||||
classToInstrument.getName() + "_NEW_NEW",
|
||||
instrumentedTwiceBytecode
|
||||
);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
|
||||
|
||||
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
|
||||
assertStaticMethodThrows(loader, targetMethod1, 123);
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
|
||||
|
||||
assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
|
||||
assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
|
||||
assertStaticMethodThrows(loader, targetMethod2, 123);
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount);
|
||||
}
|
||||
|
||||
public void testInstrumenterWorksWithOverloads() throws Exception {
|
||||
var classToInstrument = TestClassToInstrument.class;
|
||||
|
||||
Map<MethodKey, CheckMethod> checkMethods = Map.of(
|
||||
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
|
||||
getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
|
||||
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
|
||||
getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
|
||||
public void testStaticMethodOverload() throws Exception {
|
||||
Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
|
||||
Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class);
|
||||
var instrumenter = createInstrumenter(
|
||||
Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2)
|
||||
);
|
||||
var loader = instrumentTestClass(instrumenter);
|
||||
|
||||
var instrumenter = createInstrumenter(checkMethods);
|
||||
|
||||
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
|
||||
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
|
||||
}
|
||||
|
||||
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
|
||||
classToInstrument.getName() + "_NEW",
|
||||
newBytecode
|
||||
);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
|
||||
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
|
||||
|
||||
// After checking is activated, everything should throw
|
||||
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
|
||||
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
|
||||
|
||||
assertStaticMethodThrows(loader, targetMethod1, 123);
|
||||
assertStaticMethodThrows(loader, targetMethod2, 123, "abc");
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
|
||||
}
|
||||
|
||||
public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
|
||||
var classToInstrument = TestClassToInstrument.class;
|
||||
|
||||
Map<MethodKey, CheckMethod> checkMethods = Map.of(
|
||||
methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
|
||||
getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
|
||||
);
|
||||
|
||||
var instrumenter = createInstrumenter(checkMethods);
|
||||
|
||||
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
|
||||
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
|
||||
}
|
||||
|
||||
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
|
||||
classToInstrument.getName() + "_NEW",
|
||||
newBytecode
|
||||
);
|
||||
public void testInstanceMethodOverload() throws Exception {
|
||||
Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class);
|
||||
var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod));
|
||||
var loader = instrumentTestClass(instrumenter);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
|
||||
|
||||
Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
|
||||
Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance());
|
||||
|
||||
// This overload is not instrumented, so it will not throw
|
||||
testTargetClass.someMethod(123);
|
||||
assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
|
||||
expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
|
||||
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
|
||||
}
|
||||
|
||||
public void testInstrumenterWorksWithConstructors() throws Exception {
|
||||
var classToInstrument = TestClassToInstrument.class;
|
||||
|
||||
Map<MethodKey, CheckMethod> checkMethods = Map.of(
|
||||
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
|
||||
getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
|
||||
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
|
||||
getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
|
||||
);
|
||||
|
||||
var instrumenter = createInstrumenter(checkMethods);
|
||||
|
||||
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
|
||||
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
|
||||
}
|
||||
|
||||
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
|
||||
classToInstrument.getName() + "_NEW",
|
||||
newBytecode
|
||||
);
|
||||
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
|
||||
var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
|
||||
assertThat(ex.getCause(), instanceOf(TestException.class));
|
||||
var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
|
||||
assertThat(ex2.getCause(), instanceOf(TestException.class));
|
||||
public void testConstructors() throws Exception {
|
||||
Constructor<?> ctor1 = TestClassToInstrument.class.getConstructor();
|
||||
Constructor<?> ctor2 = TestClassToInstrument.class.getConstructor(int.class);
|
||||
var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2)));
|
||||
|
||||
assertCtorThrows(loader, ctor1);
|
||||
assertCtorThrows(loader, ctor2, 123);
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
|
||||
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
|
||||
}
|
||||
|
@ -373,11 +291,107 @@ public class InstrumenterTests extends ESTestCase {
|
|||
* MethodKey and instrumentationMethod with slightly different signatures (using the common interface
|
||||
* Testable) which is not what would happen when it's run by the agent.
|
||||
*/
|
||||
private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
|
||||
private static InstrumenterImpl createInstrumenter(Map<String, Executable> methods) throws NoSuchMethodException {
|
||||
Map<MethodKey, CheckMethod> checkMethods = new HashMap<>();
|
||||
for (var entry : methods.entrySet()) {
|
||||
checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue()));
|
||||
}
|
||||
String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class);
|
||||
String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class);
|
||||
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
|
||||
|
||||
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
|
||||
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
|
||||
}
|
||||
|
||||
private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException {
|
||||
var clazz = TestClassToInstrument.class;
|
||||
ClassFileInfo initial = getClassFileInfo(clazz);
|
||||
byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes());
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
|
||||
}
|
||||
return new TestLoader(clazz.getName(), newBytecode);
|
||||
}
|
||||
|
||||
private static MethodKey getMethodKey(Executable method) {
|
||||
logger.info("method key: {}", method.getName());
|
||||
String methodName = method instanceof Constructor<?> ? "<init>" : method.getName();
|
||||
return new MethodKey(
|
||||
Type.getInternalName(method.getDeclaringClass()),
|
||||
methodName,
|
||||
Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList()
|
||||
);
|
||||
}
|
||||
|
||||
private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException {
|
||||
Set<AccessFlag> flags = targetMethod.accessFlags();
|
||||
boolean isInstance = flags.contains(AccessFlag.STATIC) == false && targetMethod instanceof Method;
|
||||
int extraArgs = 1; // caller class
|
||||
if (isInstance) {
|
||||
++extraArgs;
|
||||
}
|
||||
Class<?>[] targetParameterTypes = targetMethod.getParameterTypes();
|
||||
Class<?>[] checkParameterTypes = new Class<?>[targetParameterTypes.length + extraArgs];
|
||||
checkParameterTypes[0] = Class.class;
|
||||
if (isInstance) {
|
||||
checkParameterTypes[1] = Testable.class;
|
||||
}
|
||||
System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length);
|
||||
var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes);
|
||||
return new CheckMethod(
|
||||
Type.getInternalName(MockEntitlementChecker.class),
|
||||
checkMethod.getName(),
|
||||
Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList()
|
||||
);
|
||||
}
|
||||
|
||||
private static void unwrapInvocationException(InvocationTargetException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof TestException n) {
|
||||
// Sometimes we're expecting this one!
|
||||
throw n;
|
||||
} else {
|
||||
throw new AssertionError(cause);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calling a static method of a dynamically loaded class is significantly more cumbersome
|
||||
* than calling a virtual method.
|
||||
*/
|
||||
static void callStaticMethod(Method method, Object... args) {
|
||||
try {
|
||||
method.invoke(null, args);
|
||||
} catch (InvocationTargetException e) {
|
||||
unwrapInvocationException(e);
|
||||
} catch (IllegalAccessException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) {
|
||||
Method testMethod = loader.getSameMethod(method);
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
expectThrows(TestException.class, () -> callStaticMethod(testMethod, args));
|
||||
}
|
||||
|
||||
private void assertStaticMethod(TestLoader loader, Method method, Object... args) {
|
||||
Method testMethod = loader.getSameMethod(method);
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = false;
|
||||
callStaticMethod(testMethod, args);
|
||||
}
|
||||
|
||||
private void assertCtorThrows(TestLoader loader, Constructor<?> ctor, Object... args) {
|
||||
Constructor<?> testCtor = loader.getSameConstructor(ctor);
|
||||
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
|
||||
expectThrows(TestException.class, () -> {
|
||||
try {
|
||||
testCtor.newInstance(args);
|
||||
} catch (InvocationTargetException e) {
|
||||
unwrapInvocationException(e);
|
||||
} catch (IllegalAccessException | InstantiationException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.entitlement.instrumentation.impl;
|
||||
|
||||
class TestLoader extends ClassLoader {
|
||||
TestLoader(ClassLoader parent) {
|
||||
super(parent);
|
||||
}
|
||||
|
||||
public Class<?> defineClassFromBytes(String name, byte[] bytes) {
|
||||
return defineClass(name, bytes, 0, bytes.length);
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.entitlement.instrumentation.impl;
|
||||
|
||||
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
|
||||
import org.elasticsearch.entitlement.instrumentation.MethodKey;
|
||||
import org.objectweb.asm.Type;
|
||||
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
class TestMethodUtils {
|
||||
|
||||
/**
|
||||
* @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
|
||||
*/
|
||||
static MethodKey methodKeyForTarget(Method targetMethod) {
|
||||
Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
|
||||
return new MethodKey(
|
||||
Type.getInternalName(targetMethod.getDeclaringClass()),
|
||||
targetMethod.getName(),
|
||||
Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
|
||||
);
|
||||
}
|
||||
|
||||
static MethodKey methodKeyForConstructor(Class<?> classToInstrument, List<String> params) {
|
||||
return new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", params);
|
||||
}
|
||||
|
||||
static CheckMethod getCheckMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) throws NoSuchMethodException {
|
||||
var method = clazz.getMethod(methodName, parameterTypes);
|
||||
return new CheckMethod(
|
||||
Type.getInternalName(clazz),
|
||||
method.getName(),
|
||||
Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calling a static method of a dynamically loaded class is significantly more cumbersome
|
||||
* than calling a virtual method.
|
||||
*/
|
||||
static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
|
||||
try {
|
||||
c.getMethod(methodName, int.class).invoke(null, arg);
|
||||
} catch (InvocationTargetException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof TestException n) {
|
||||
// Sometimes we're expecting this one!
|
||||
throw n;
|
||||
} else {
|
||||
throw new AssertionError(cause);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
|
||||
IllegalAccessException {
|
||||
try {
|
||||
c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
|
||||
} catch (InvocationTargetException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof TestException n) {
|
||||
// Sometimes we're expecting this one!
|
||||
throw n;
|
||||
} else {
|
||||
throw new AssertionError(cause);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap;
|
|||
import org.elasticsearch.entitlement.bridge.EntitlementChecker;
|
||||
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
|
||||
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
|
||||
import org.elasticsearch.entitlement.instrumentation.Instrumenter;
|
||||
import org.elasticsearch.entitlement.instrumentation.MethodKey;
|
||||
import org.elasticsearch.entitlement.instrumentation.Transformer;
|
||||
import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker;
|
||||
|
@ -64,13 +65,12 @@ public class EntitlementInitialization {
|
|||
public static void initialize(Instrumentation inst) throws Exception {
|
||||
manager = initChecker();
|
||||
|
||||
Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
|
||||
"org.elasticsearch.entitlement.bridge.EntitlementChecker"
|
||||
);
|
||||
Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethods(EntitlementChecker.class);
|
||||
|
||||
var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
|
||||
|
||||
inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true);
|
||||
Instrumenter instrumenter = INSTRUMENTER_FACTORY.newInstrumenter(EntitlementChecker.class, checkMethods);
|
||||
inst.addTransformer(new Transformer(instrumenter, classesToTransform), true);
|
||||
// TODO: should we limit this array somehow?
|
||||
var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new);
|
||||
inst.retransformClasses(classesToRetransform);
|
||||
|
|
|
@ -16,7 +16,7 @@ import java.util.Map;
|
|||
* The SPI service entry point for instrumentation.
|
||||
*/
|
||||
public interface InstrumentationService {
|
||||
Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods);
|
||||
Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods);
|
||||
|
||||
Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException;
|
||||
Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws IOException;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue