From ef85d0a53f1f58a63359b63933fc1e147167d42f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lorenzo=20Dematt=C3=A9?= Date: Mon, 28 Oct 2024 09:31:19 +0100 Subject: [PATCH] Avoid double instrumentation via class annotation (#115398) --- .../impl/InstrumenterImpl.java | 96 ++++++++++++--- .../impl/InstrumenterTests.java | 112 +++++++++++++++--- 2 files changed, 177 insertions(+), 31 deletions(-) diff --git a/distribution/tools/entitlement-agent/impl/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java b/distribution/tools/entitlement-agent/impl/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java index 81c120ddcd6d..7c2e1645ada8 100644 --- a/distribution/tools/entitlement-agent/impl/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java +++ b/distribution/tools/entitlement-agent/impl/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java @@ -15,8 +15,10 @@ import org.objectweb.asm.AnnotationVisitor; import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.FieldVisitor; import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; +import org.objectweb.asm.RecordComponentVisitor; import org.objectweb.asm.Type; import java.io.IOException; @@ -73,7 +75,13 @@ public class InstrumenterImpl implements Instrumenter { } class EntitlementClassVisitor extends ClassVisitor { - final String className; + + private static final String ENTITLEMENT_ANNOTATION = "EntitlementInstrumented"; + + private final String className; + + private boolean isAnnotationPresent; + private boolean annotationNeeded = true; EntitlementClassVisitor(int api, ClassVisitor classVisitor, String className) { super(api, classVisitor); @@ -85,25 +93,85 @@ public class InstrumenterImpl implements Instrumenter { super.visit(version, access, name + classNameSuffix, signature, superName, interfaces); } + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + if (visible && descriptor.equals(ENTITLEMENT_ANNOTATION)) { + isAnnotationPresent = true; + annotationNeeded = false; + } + return cv.visitAnnotation(descriptor, visible); + } + + @Override + public void visitNestMember(String nestMember) { + addClassAnnotationIfNeeded(); + super.visitNestMember(nestMember); + } + + @Override + public void visitPermittedSubclass(String permittedSubclass) { + addClassAnnotationIfNeeded(); + super.visitPermittedSubclass(permittedSubclass); + } + + @Override + public void visitInnerClass(String name, String outerName, String innerName, int access) { + addClassAnnotationIfNeeded(); + super.visitInnerClass(name, outerName, innerName, access); + } + + @Override + public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) { + addClassAnnotationIfNeeded(); + return super.visitField(access, name, descriptor, signature, value); + } + + @Override + public RecordComponentVisitor visitRecordComponent(String name, String descriptor, String signature) { + addClassAnnotationIfNeeded(); + return super.visitRecordComponent(name, descriptor, signature); + } + @Override public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { + addClassAnnotationIfNeeded(); var mv = super.visitMethod(access, name, descriptor, signature, exceptions); - boolean isStatic = (access & ACC_STATIC) != 0; - var key = new MethodKey( - className, - name, - Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList(), - isStatic - ); - var instrumentationMethod = instrumentationMethods.get(key); - if (instrumentationMethod != null) { - // LOGGER.debug("Will instrument method {}", key); - return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod); - } else { - // LOGGER.trace("Will not instrument method {}", key); + if (isAnnotationPresent == false) { + boolean isStatic = (access & ACC_STATIC) != 0; + var key = new MethodKey( + className, + name, + Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList(), + isStatic + ); + var instrumentationMethod = instrumentationMethods.get(key); + if (instrumentationMethod != null) { + // LOGGER.debug("Will instrument method {}", key); + return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod); + } else { + // LOGGER.trace("Will not instrument method {}", key); + } } return mv; } + + /** + * A class annotation can be added via visitAnnotation; we need to call visitAnnotation after all other visitAnnotation + * calls (in case one of them detects our annotation is already present), but before any other subsequent visit* method is called + * (up to visitMethod -- if no visitMethod is called, there is nothing to instrument). + * This includes visitNestMember, visitPermittedSubclass, visitInnerClass, visitField, visitRecordComponent and, of course, + * visitMethod (see {@link ClassVisitor} javadoc). + */ + private void addClassAnnotationIfNeeded() { + if (annotationNeeded) { + // logger.debug("Adding {} annotation", ENTITLEMENT_ANNOTATION); + AnnotationVisitor av = cv.visitAnnotation(ENTITLEMENT_ANNOTATION, true); + if (av != null) { + av.visitEnd(); + } + annotationNeeded = false; + } + } } static class EntitlementMethodVisitor extends MethodVisitor { diff --git a/distribution/tools/entitlement-agent/impl/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java b/distribution/tools/entitlement-agent/impl/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java index e807ecee4f10..f05c7ccae62e 100644 --- a/distribution/tools/entitlement-agent/impl/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java +++ b/distribution/tools/entitlement-agent/impl/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java @@ -9,20 +9,24 @@ package org.elasticsearch.entitlement.instrumentation.impl; +import org.elasticsearch.common.Strings; import org.elasticsearch.entitlement.api.EntitlementChecks; import org.elasticsearch.entitlement.api.EntitlementProvider; import org.elasticsearch.entitlement.instrumentation.InstrumentationService; -import org.elasticsearch.entitlement.instrumentation.MethodKey; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.test.ESTestCase; import org.junit.Before; +import org.objectweb.asm.Type; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.Map; +import java.util.Arrays; +import java.util.stream.Collectors; import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text; +import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo; +import static org.hamcrest.Matchers.is; /** * This tests {@link InstrumenterImpl} in isolation, without a java agent. @@ -60,6 +64,10 @@ public class InstrumenterTests extends ESTestCase { public static void systemExit(int status) { assertEquals(123, status); } + + public static void anotherSystemExit(int status) { + assertEquals(123, status); + } } static final class TestException extends RuntimeException {} @@ -76,8 +84,11 @@ public class InstrumenterTests extends ESTestCase { */ volatile boolean isActive; + int checkSystemExitCallCount = 0; + @Override public void checkSystemExit(Class callerClass, int status) { + checkSystemExitCallCount++; assertSame(InstrumenterTests.class, callerClass); assertEquals(123, status); throwIfActive(); @@ -90,18 +101,11 @@ public class InstrumenterTests extends ESTestCase { } } - public void test() throws Exception { - // This test doesn't replace ClassToInstrument in-place but instead loads a separate - // class ClassToInstrument_NEW that contains the instrumentation. Because of this, - // we need to configure the Transformer to use a MethodKey and instrumentationMethod - // with slightly different signatures (using the common interface Testable) which - // is not what would happen when it's run by the agent. + public void testClassIsInstrumented() throws Exception { + var classToInstrument = ClassToInstrument.class; + var instrumenter = createInstrumenter(classToInstrument, "systemExit"); - MethodKey k1 = instrumentationService.methodKeyForTarget(ClassToInstrument.class.getMethod("systemExit", int.class)); - Method v1 = EntitlementChecks.class.getMethod("checkSystemExit", Class.class, int.class); - var instrumenter = new InstrumenterImpl("_NEW", Map.of(k1, v1)); - - byte[] newBytecode = instrumenter.instrumentClassFile(ClassToInstrument.class).bytecodes(); + byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes(); if (logger.isTraceEnabled()) { logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode)); @@ -112,22 +116,96 @@ public class InstrumenterTests extends ESTestCase { newBytecode ); + getTestChecks().isActive = false; + // Before checking is active, nothing should throw - callStaticSystemExit(newClass, 123); + callStaticMethod(newClass, "systemExit", 123); getTestChecks().isActive = true; // After checking is activated, everything should throw - assertThrows(TestException.class, () -> callStaticSystemExit(newClass, 123)); + assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123)); + } + + public void testClassIsNotInstrumentedTwice() throws Exception { + var classToInstrument = ClassToInstrument.class; + var instrumenter = createInstrumenter(classToInstrument, "systemExit"); + + InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument); + var internalClassName = Type.getInternalName(classToInstrument); + + byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes()); + byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode); + + logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode))); + logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode))); + + Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( + ClassToInstrument.class.getName() + "_NEW_NEW", + instrumentedTwiceBytecode + ); + + getTestChecks().isActive = true; + getTestChecks().checkSystemExitCallCount = 0; + + assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123)); + assertThat(getTestChecks().checkSystemExitCallCount, is(1)); + } + + public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception { + var classToInstrument = ClassToInstrument.class; + var instrumenter = createInstrumenter(classToInstrument, "systemExit", "anotherSystemExit"); + + InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument); + var internalClassName = Type.getInternalName(classToInstrument); + + byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes()); + byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode); + + logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode))); + logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode))); + + Class newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes( + ClassToInstrument.class.getName() + "_NEW_NEW", + instrumentedTwiceBytecode + ); + + getTestChecks().isActive = true; + getTestChecks().checkSystemExitCallCount = 0; + + assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123)); + assertThat(getTestChecks().checkSystemExitCallCount, is(1)); + + assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherSystemExit", 123)); + assertThat(getTestChecks().checkSystemExitCallCount, is(2)); + } + + /** This test doesn't replace ClassToInstrument in-place but instead loads a separate + * class ClassToInstrument_NEW that contains the instrumentation. Because of this, + * we need to configure the Transformer to use a MethodKey and instrumentationMethod + * with slightly different signatures (using the common interface Testable) which + * is not what would happen when it's run by the agent. + */ + private InstrumenterImpl createInstrumenter(Class classToInstrument, String... methodNames) throws NoSuchMethodException { + Method v1 = EntitlementChecks.class.getMethod("checkSystemExit", Class.class, int.class); + var methods = Arrays.stream(methodNames).map(name -> { + try { + return instrumentationService.methodKeyForTarget(classToInstrument.getMethod(name, int.class)); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + }).collect(Collectors.toUnmodifiableMap(name -> name, name -> v1)); + + return new InstrumenterImpl("_NEW", methods); } /** * Calling a static method of a dynamically loaded class is significantly more cumbersome * than calling a virtual method. */ - private static void callStaticSystemExit(Class c, int status) throws NoSuchMethodException, IllegalAccessException { + private static void callStaticMethod(Class c, String methodName, int status) throws NoSuchMethodException, IllegalAccessException { try { - c.getMethod("systemExit", int.class).invoke(null, status); + c.getMethod(methodName, int.class).invoke(null, status); } catch (InvocationTargetException e) { Throwable cause = e.getCause(); if (cause instanceof TestException n) {