Avoid double instrumentation via class annotation (#115398)

This commit is contained in:
Lorenzo Dematté 2024-10-28 09:31:19 +01:00 committed by GitHub
parent 98cd34f3fd
commit ef85d0a53f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 177 additions and 31 deletions

View file

@ -15,8 +15,10 @@ import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.RecordComponentVisitor;
import org.objectweb.asm.Type;
import java.io.IOException;
@ -73,7 +75,13 @@ public class InstrumenterImpl implements Instrumenter {
}
class EntitlementClassVisitor extends ClassVisitor {
final String className;
private static final String ENTITLEMENT_ANNOTATION = "EntitlementInstrumented";
private final String className;
private boolean isAnnotationPresent;
private boolean annotationNeeded = true;
EntitlementClassVisitor(int api, ClassVisitor classVisitor, String className) {
super(api, classVisitor);
@ -85,25 +93,85 @@ public class InstrumenterImpl implements Instrumenter {
super.visit(version, access, name + classNameSuffix, signature, superName, interfaces);
}
@Override
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
if (visible && descriptor.equals(ENTITLEMENT_ANNOTATION)) {
isAnnotationPresent = true;
annotationNeeded = false;
}
return cv.visitAnnotation(descriptor, visible);
}
@Override
public void visitNestMember(String nestMember) {
addClassAnnotationIfNeeded();
super.visitNestMember(nestMember);
}
@Override
public void visitPermittedSubclass(String permittedSubclass) {
addClassAnnotationIfNeeded();
super.visitPermittedSubclass(permittedSubclass);
}
@Override
public void visitInnerClass(String name, String outerName, String innerName, int access) {
addClassAnnotationIfNeeded();
super.visitInnerClass(name, outerName, innerName, access);
}
@Override
public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
addClassAnnotationIfNeeded();
return super.visitField(access, name, descriptor, signature, value);
}
@Override
public RecordComponentVisitor visitRecordComponent(String name, String descriptor, String signature) {
addClassAnnotationIfNeeded();
return super.visitRecordComponent(name, descriptor, signature);
}
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
addClassAnnotationIfNeeded();
var mv = super.visitMethod(access, name, descriptor, signature, exceptions);
boolean isStatic = (access & ACC_STATIC) != 0;
var key = new MethodKey(
className,
name,
Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList(),
isStatic
);
var instrumentationMethod = instrumentationMethods.get(key);
if (instrumentationMethod != null) {
// LOGGER.debug("Will instrument method {}", key);
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod);
} else {
// LOGGER.trace("Will not instrument method {}", key);
if (isAnnotationPresent == false) {
boolean isStatic = (access & ACC_STATIC) != 0;
var key = new MethodKey(
className,
name,
Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList(),
isStatic
);
var instrumentationMethod = instrumentationMethods.get(key);
if (instrumentationMethod != null) {
// LOGGER.debug("Will instrument method {}", key);
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod);
} else {
// LOGGER.trace("Will not instrument method {}", key);
}
}
return mv;
}
/**
* A class annotation can be added via visitAnnotation; we need to call visitAnnotation after all other visitAnnotation
* calls (in case one of them detects our annotation is already present), but before any other subsequent visit* method is called
* (up to visitMethod -- if no visitMethod is called, there is nothing to instrument).
* This includes visitNestMember, visitPermittedSubclass, visitInnerClass, visitField, visitRecordComponent and, of course,
* visitMethod (see {@link ClassVisitor} javadoc).
*/
private void addClassAnnotationIfNeeded() {
if (annotationNeeded) {
// logger.debug("Adding {} annotation", ENTITLEMENT_ANNOTATION);
AnnotationVisitor av = cv.visitAnnotation(ENTITLEMENT_ANNOTATION, true);
if (av != null) {
av.visitEnd();
}
annotationNeeded = false;
}
}
}
static class EntitlementMethodVisitor extends MethodVisitor {

View file

@ -9,20 +9,24 @@
package org.elasticsearch.entitlement.instrumentation.impl;
import org.elasticsearch.common.Strings;
import org.elasticsearch.entitlement.api.EntitlementChecks;
import org.elasticsearch.entitlement.api.EntitlementProvider;
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.objectweb.asm.Type;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Arrays;
import java.util.stream.Collectors;
import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
import static org.hamcrest.Matchers.is;
/**
* This tests {@link InstrumenterImpl} in isolation, without a java agent.
@ -60,6 +64,10 @@ public class InstrumenterTests extends ESTestCase {
public static void systemExit(int status) {
assertEquals(123, status);
}
public static void anotherSystemExit(int status) {
assertEquals(123, status);
}
}
static final class TestException extends RuntimeException {}
@ -76,8 +84,11 @@ public class InstrumenterTests extends ESTestCase {
*/
volatile boolean isActive;
int checkSystemExitCallCount = 0;
@Override
public void checkSystemExit(Class<?> callerClass, int status) {
checkSystemExitCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, status);
throwIfActive();
@ -90,18 +101,11 @@ public class InstrumenterTests extends ESTestCase {
}
}
public void test() throws Exception {
// This test doesn't replace ClassToInstrument in-place but instead loads a separate
// class ClassToInstrument_NEW that contains the instrumentation. Because of this,
// we need to configure the Transformer to use a MethodKey and instrumentationMethod
// with slightly different signatures (using the common interface Testable) which
// is not what would happen when it's run by the agent.
public void testClassIsInstrumented() throws Exception {
var classToInstrument = ClassToInstrument.class;
var instrumenter = createInstrumenter(classToInstrument, "systemExit");
MethodKey k1 = instrumentationService.methodKeyForTarget(ClassToInstrument.class.getMethod("systemExit", int.class));
Method v1 = EntitlementChecks.class.getMethod("checkSystemExit", Class.class, int.class);
var instrumenter = new InstrumenterImpl("_NEW", Map.of(k1, v1));
byte[] newBytecode = instrumenter.instrumentClassFile(ClassToInstrument.class).bytecodes();
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
@ -112,22 +116,96 @@ public class InstrumenterTests extends ESTestCase {
newBytecode
);
getTestChecks().isActive = false;
// Before checking is active, nothing should throw
callStaticSystemExit(newClass, 123);
callStaticMethod(newClass, "systemExit", 123);
getTestChecks().isActive = true;
// After checking is activated, everything should throw
assertThrows(TestException.class, () -> callStaticSystemExit(newClass, 123));
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
}
public void testClassIsNotInstrumentedTwice() throws Exception {
var classToInstrument = ClassToInstrument.class;
var instrumenter = createInstrumenter(classToInstrument, "systemExit");
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
var internalClassName = Type.getInternalName(classToInstrument);
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
ClassToInstrument.class.getName() + "_NEW_NEW",
instrumentedTwiceBytecode
);
getTestChecks().isActive = true;
getTestChecks().checkSystemExitCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
assertThat(getTestChecks().checkSystemExitCallCount, is(1));
}
public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
var classToInstrument = ClassToInstrument.class;
var instrumenter = createInstrumenter(classToInstrument, "systemExit", "anotherSystemExit");
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
var internalClassName = Type.getInternalName(classToInstrument);
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
ClassToInstrument.class.getName() + "_NEW_NEW",
instrumentedTwiceBytecode
);
getTestChecks().isActive = true;
getTestChecks().checkSystemExitCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
assertThat(getTestChecks().checkSystemExitCallCount, is(1));
assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherSystemExit", 123));
assertThat(getTestChecks().checkSystemExitCallCount, is(2));
}
/** This test doesn't replace ClassToInstrument in-place but instead loads a separate
* class ClassToInstrument_NEW that contains the instrumentation. Because of this,
* we need to configure the Transformer to use a MethodKey and instrumentationMethod
* with slightly different signatures (using the common interface Testable) which
* is not what would happen when it's run by the agent.
*/
private InstrumenterImpl createInstrumenter(Class<?> classToInstrument, String... methodNames) throws NoSuchMethodException {
Method v1 = EntitlementChecks.class.getMethod("checkSystemExit", Class.class, int.class);
var methods = Arrays.stream(methodNames).map(name -> {
try {
return instrumentationService.methodKeyForTarget(classToInstrument.getMethod(name, int.class));
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}).collect(Collectors.toUnmodifiableMap(name -> name, name -> v1));
return new InstrumenterImpl("_NEW", methods);
}
/**
* Calling a static method of a dynamically loaded class is significantly more cumbersome
* than calling a virtual method.
*/
private static void callStaticSystemExit(Class<?> c, int status) throws NoSuchMethodException, IllegalAccessException {
private static void callStaticMethod(Class<?> c, String methodName, int status) throws NoSuchMethodException, IllegalAccessException {
try {
c.getMethod("systemExit", int.class).invoke(null, status);
c.getMethod(methodName, int.class).invoke(null, status);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {