Merge revision ce2a7dee86 into multi-project

This commit is contained in:
Tim Vernum 2024-12-15 23:45:15 +11:00
commit affd6dfb5b
357 changed files with 9574 additions and 2238 deletions

View file

@ -29,15 +29,13 @@ import java.util.Map;
public class InstrumentationServiceImpl implements InstrumentationService {
@Override
public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
return InstrumenterImpl.create(checkMethods);
public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods) {
return InstrumenterImpl.create(clazz, methods);
}
@Override
public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
IOException {
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
var checkerClass = Class.forName(entitlementCheckerClassName);
var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {

View file

@ -58,30 +58,14 @@ public class InstrumenterImpl implements Instrumenter {
this.checkMethods = checkMethods;
}
static String getCheckerClassName() {
int javaVersion = Runtime.version().feature();
final String classNamePrefix;
if (javaVersion >= 23) {
classNamePrefix = "Java23";
} else {
classNamePrefix = "";
}
return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
}
public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
String checkerClass = getCheckerClassName();
String handleClass = checkerClass + "Handle";
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
public static InstrumenterImpl create(Class<?> checkerClass, Map<MethodKey, CheckMethod> checkMethods) {
Type checkerClassType = Type.getType(checkerClass);
String handleClass = checkerClassType.getInternalName() + "Handle";
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(checkerClassType);
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
}
public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
ClassFileInfo initial = getClassFileInfo(clazz);
return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes()));
}
public static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
String internalName = Type.getInternalName(clazz);
String fileName = "/" + internalName + ".class";
byte[] originalBytecodes;
@ -306,5 +290,5 @@ public class InstrumenterImpl implements Instrumenter {
mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
}
public record ClassFileInfo(String fileName, byte[] bytecodes) {}
record ClassFileInfo(String fileName, byte[] bytecodes) {}
}

View file

@ -51,8 +51,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
}
public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
public void testInstrumentationTargetLookup() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
assertThat(checkMethods, aMapWithSize(3));
assertThat(
@ -116,8 +116,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
);
}
public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
public void testInstrumentationTargetLookupWithOverloads() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);
assertThat(checkMethods, aMapWithSize(2));
assertThat(
@ -148,8 +148,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
);
}
public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
public void testInstrumentationTargetLookupWithCtors() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);
assertThat(checkMethods, aMapWithSize(2));
assertThat(

View file

@ -12,31 +12,64 @@ package org.elasticsearch.entitlement.instrumentation.impl;
import org.elasticsearch.common.Strings;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.objectweb.asm.Type;
import java.io.IOException;
import java.lang.reflect.AccessFlag;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.hamcrest.Matchers.equalTo;
/**
* This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
* some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
* This tests {@link InstrumenterImpl} can instrument various method signatures
* (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
*/
@ESTestCase.WithoutSecurityManager
public class InstrumenterTests extends ESTestCase {
private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
static class TestLoader extends ClassLoader {
final byte[] testClassBytes;
final Class<?> testClass;
TestLoader(String testClassName, byte[] testClassBytes) {
super(InstrumenterTests.class.getClassLoader());
this.testClassBytes = testClassBytes;
this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length);
}
Method getSameMethod(Method method) {
try {
return testClass.getMethod(method.getName(), method.getParameterTypes());
} catch (NoSuchMethodException e) {
throw new AssertionError(e);
}
}
Constructor<?> getSameConstructor(Constructor<?> ctor) {
try {
return testClass.getConstructor(ctor.getParameterTypes());
} catch (NoSuchMethodException e) {
throw new AssertionError(e);
}
}
}
/**
* Contains all the virtual methods from {@link TestClassToInstrument},
* allowing this test to call them on the dynamically loaded instrumented class.
@ -80,13 +113,15 @@ public class InstrumenterTests extends ESTestCase {
public interface MockEntitlementChecker {
void checkSomeStaticMethod(Class<?> clazz, int arg);
void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
void checkSomeStaticMethodOverload(Class<?> clazz, int arg, String anotherArg);
void checkAnotherStaticMethod(Class<?> clazz, int arg);
void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
void checkCtor(Class<?> clazz);
void checkCtor(Class<?> clazz, int arg);
void checkCtorOverload(Class<?> clazz, int arg);
}
public static class TestEntitlementCheckerHolder {
@ -105,6 +140,7 @@ public class InstrumenterTests extends ESTestCase {
volatile boolean isActive;
int checkSomeStaticMethodIntCallCount = 0;
int checkAnotherStaticMethodIntCallCount = 0;
int checkSomeStaticMethodIntStringCallCount = 0;
int checkSomeInstanceMethodCallCount = 0;
@ -120,28 +156,33 @@ public class InstrumenterTests extends ESTestCase {
@Override
public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
checkSomeStaticMethodIntCallCount++;
assertSame(TestMethodUtils.class, callerClass);
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
throwIfActive();
}
@Override
public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
public void checkSomeStaticMethodOverload(Class<?> callerClass, int arg, String anotherArg) {
checkSomeStaticMethodIntStringCallCount++;
assertSame(TestMethodUtils.class, callerClass);
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
assertEquals("abc", anotherArg);
throwIfActive();
}
@Override
public void checkAnotherStaticMethod(Class<?> callerClass, int arg) {
checkAnotherStaticMethodIntCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
throwIfActive();
}
@Override
public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
checkSomeInstanceMethodCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertThat(
that.getClass().getName(),
startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument")
);
assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName()));
assertEquals(123, arg);
assertEquals("def", anotherArg);
throwIfActive();
@ -155,7 +196,7 @@ public class InstrumenterTests extends ESTestCase {
}
@Override
public void checkCtor(Class<?> callerClass, int arg) {
public void checkCtorOverload(Class<?> callerClass, int arg) {
checkCtorIntCallCount++;
assertSame(InstrumenterTests.class, callerClass);
assertEquals(123, arg);
@ -163,206 +204,83 @@ public class InstrumenterTests extends ESTestCase {
}
}
public void testClassIsInstrumented() throws Exception {
var classToInstrument = TestClassToInstrument.class;
@Before
public void resetInstance() {
TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker();
}
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
checkMethod
);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = false;
public void testStaticMethod() throws Exception {
Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)));
// Before checking is active, nothing should throw
callStaticMethod(newClass, "someStaticMethod", 123);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
assertStaticMethod(loader, targetMethod, 123);
// After checking is activated, everything should throw
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertStaticMethodThrows(loader, targetMethod, 123);
}
public void testClassIsNotInstrumentedTwice() throws Exception {
var classToInstrument = TestClassToInstrument.class;
public void testNotInstrumentedTwice() throws Exception {
Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod));
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
checkMethod
);
var instrumenter = createInstrumenter(checkMethods);
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
var internalClassName = Type.getInternalName(classToInstrument);
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
var loader1 = instrumentTestClass(instrumenter);
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes);
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode);
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW_NEW",
instrumentedTwiceBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertStaticMethodThrows(loader2, targetMethod, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
}
public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
var classToInstrument = TestClassToInstrument.class;
public void testMultipleMethods() throws Exception {
Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class);
CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
checkMethod,
methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
checkMethod
);
var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2));
var loader = instrumentTestClass(instrumenter);
var instrumenter = createInstrumenter(checkMethods);
InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
var internalClassName = Type.getInternalName(classToInstrument);
byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW_NEW",
instrumentedTwiceBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertStaticMethodThrows(loader, targetMethod1, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
assertStaticMethodThrows(loader, targetMethod2, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount);
}
public void testInstrumenterWorksWithOverloads() throws Exception {
var classToInstrument = TestClassToInstrument.class;
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
public void testStaticMethodOverload() throws Exception {
Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class);
var instrumenter = createInstrumenter(
Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2)
);
var loader = instrumentTestClass(instrumenter);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
// After checking is activated, everything should throw
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
assertStaticMethodThrows(loader, targetMethod1, 123);
assertStaticMethodThrows(loader, targetMethod2, 123, "abc");
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
}
public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
var classToInstrument = TestClassToInstrument.class;
Map<MethodKey, CheckMethod> checkMethods = Map.of(
methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
public void testInstanceMethodOverload() throws Exception {
Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class);
var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod));
var loader = instrumentTestClass(instrumenter);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance());
// This overload is not instrumented, so it will not throw
testTargetClass.someMethod(123);
assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
}
public void testInstrumenterWorksWithConstructors() throws Exception {
var classToInstrument = TestClassToInstrument.class;
Map<MethodKey, CheckMethod> checkMethods = Map.of(
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
);
var instrumenter = createInstrumenter(checkMethods);
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
classToInstrument.getName() + "_NEW",
newBytecode
);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
assertThat(ex.getCause(), instanceOf(TestException.class));
var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
assertThat(ex2.getCause(), instanceOf(TestException.class));
public void testConstructors() throws Exception {
Constructor<?> ctor1 = TestClassToInstrument.class.getConstructor();
Constructor<?> ctor2 = TestClassToInstrument.class.getConstructor(int.class);
var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2)));
assertCtorThrows(loader, ctor1);
assertCtorThrows(loader, ctor2, 123);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
}
@ -373,11 +291,107 @@ public class InstrumenterTests extends ESTestCase {
* MethodKey and instrumentationMethod with slightly different signatures (using the common interface
* Testable) which is not what would happen when it's run by the agent.
*/
private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
private static InstrumenterImpl createInstrumenter(Map<String, Executable> methods) throws NoSuchMethodException {
Map<MethodKey, CheckMethod> checkMethods = new HashMap<>();
for (var entry : methods.entrySet()) {
checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue()));
}
String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class);
String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class);
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
}
private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException {
var clazz = TestClassToInstrument.class;
ClassFileInfo initial = getClassFileInfo(clazz);
byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes());
if (logger.isTraceEnabled()) {
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
}
return new TestLoader(clazz.getName(), newBytecode);
}
private static MethodKey getMethodKey(Executable method) {
logger.info("method key: {}", method.getName());
String methodName = method instanceof Constructor<?> ? "<init>" : method.getName();
return new MethodKey(
Type.getInternalName(method.getDeclaringClass()),
methodName,
Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList()
);
}
private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException {
Set<AccessFlag> flags = targetMethod.accessFlags();
boolean isInstance = flags.contains(AccessFlag.STATIC) == false && targetMethod instanceof Method;
int extraArgs = 1; // caller class
if (isInstance) {
++extraArgs;
}
Class<?>[] targetParameterTypes = targetMethod.getParameterTypes();
Class<?>[] checkParameterTypes = new Class<?>[targetParameterTypes.length + extraArgs];
checkParameterTypes[0] = Class.class;
if (isInstance) {
checkParameterTypes[1] = Testable.class;
}
System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length);
var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes);
return new CheckMethod(
Type.getInternalName(MockEntitlementChecker.class),
checkMethod.getName(),
Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList()
);
}
private static void unwrapInvocationException(InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {
// Sometimes we're expecting this one!
throw n;
} else {
throw new AssertionError(cause);
}
}
/**
* Calling a static method of a dynamically loaded class is significantly more cumbersome
* than calling a virtual method.
*/
static void callStaticMethod(Method method, Object... args) {
try {
method.invoke(null, args);
} catch (InvocationTargetException e) {
unwrapInvocationException(e);
} catch (IllegalAccessException e) {
throw new AssertionError(e);
}
}
private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) {
Method testMethod = loader.getSameMethod(method);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
expectThrows(TestException.class, () -> callStaticMethod(testMethod, args));
}
private void assertStaticMethod(TestLoader loader, Method method, Object... args) {
Method testMethod = loader.getSameMethod(method);
TestEntitlementCheckerHolder.checkerInstance.isActive = false;
callStaticMethod(testMethod, args);
}
private void assertCtorThrows(TestLoader loader, Constructor<?> ctor, Object... args) {
Constructor<?> testCtor = loader.getSameConstructor(ctor);
TestEntitlementCheckerHolder.checkerInstance.isActive = true;
expectThrows(TestException.class, () -> {
try {
testCtor.newInstance(args);
} catch (InvocationTargetException e) {
unwrapInvocationException(e);
} catch (IllegalAccessException | InstantiationException e) {
throw new AssertionError(e);
}
});
}
}

View file

@ -1,20 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.instrumentation.impl;
class TestLoader extends ClassLoader {
TestLoader(ClassLoader parent) {
super(parent);
}
public Class<?> defineClassFromBytes(String name, byte[] bytes) {
return defineClass(name, bytes, 0, bytes.length);
}
}

View file

@ -1,81 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.instrumentation.impl;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.objectweb.asm.Type;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
class TestMethodUtils {
/**
* @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
*/
static MethodKey methodKeyForTarget(Method targetMethod) {
Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
return new MethodKey(
Type.getInternalName(targetMethod.getDeclaringClass()),
targetMethod.getName(),
Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
);
}
static MethodKey methodKeyForConstructor(Class<?> classToInstrument, List<String> params) {
return new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", params);
}
static CheckMethod getCheckMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) throws NoSuchMethodException {
var method = clazz.getMethod(methodName, parameterTypes);
return new CheckMethod(
Type.getInternalName(clazz),
method.getName(),
Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
);
}
/**
* Calling a static method of a dynamically loaded class is significantly more cumbersome
* than calling a virtual method.
*/
static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
try {
c.getMethod(methodName, int.class).invoke(null, arg);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {
// Sometimes we're expecting this one!
throw n;
} else {
throw new AssertionError(cause);
}
}
}
static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
IllegalAccessException {
try {
c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof TestException n) {
// Sometimes we're expecting this one!
throw n;
} else {
throw new AssertionError(cause);
}
}
}
}

View file

@ -13,7 +13,11 @@ import java.net.URL;
import java.net.URLStreamHandlerFactory;
public interface EntitlementChecker {
void check$java_lang_System$exit(Class<?> callerClass, int status);
// Exit the JVM process
void check$$exit(Class<?> callerClass, Runtime runtime, int status);
void check$$halt(Class<?> callerClass, Runtime runtime, int status);
// URLClassLoader ctor
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls);

View file

@ -47,14 +47,21 @@ public class RestEntitlementsCheckAction extends BaseRestHandler {
}
private static final Map<String, CheckAction> checkActions = Map.ofEntries(
entry("system_exit", CheckAction.serverOnly(RestEntitlementsCheckAction::systemExit)),
entry("runtime_exit", CheckAction.serverOnly(RestEntitlementsCheckAction::runtimeExit)),
entry("runtime_halt", CheckAction.serverOnly(RestEntitlementsCheckAction::runtimeHalt)),
entry("create_classloader", CheckAction.serverAndPlugin(RestEntitlementsCheckAction::createClassLoader))
);
@SuppressForbidden(reason = "Specifically testing System.exit")
private static void systemExit() {
logger.info("Calling System.exit(123);");
System.exit(123);
@SuppressForbidden(reason = "Specifically testing Runtime.exit")
private static void runtimeExit() {
logger.info("Calling Runtime.exit;");
Runtime.getRuntime().exit(123);
}
@SuppressForbidden(reason = "Specifically testing Runtime.halt")
private static void runtimeHalt() {
logger.info("Calling Runtime.halt;");
Runtime.getRuntime().halt(123);
}
private static void createClassLoader() {

View file

@ -14,6 +14,7 @@ import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap;
import org.elasticsearch.entitlement.bridge.EntitlementChecker;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
import org.elasticsearch.entitlement.instrumentation.Instrumenter;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.elasticsearch.entitlement.instrumentation.Transformer;
import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker;
@ -64,13 +65,12 @@ public class EntitlementInitialization {
public static void initialize(Instrumentation inst) throws Exception {
manager = initChecker();
Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
"org.elasticsearch.entitlement.bridge.EntitlementChecker"
);
Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethods(EntitlementChecker.class);
var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true);
Instrumenter instrumenter = INSTRUMENTER_FACTORY.newInstrumenter(EntitlementChecker.class, checkMethods);
inst.addTransformer(new Transformer(instrumenter, classesToTransform), true);
// TODO: should we limit this array somehow?
var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new);
inst.retransformClasses(classesToRetransform);

View file

@ -16,7 +16,7 @@ import java.util.Map;
* The SPI service entry point for instrumentation.
*/
public interface InstrumentationService {
Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods);
Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods);
Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException;
Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws IOException;
}

View file

@ -28,7 +28,12 @@ public class ElasticsearchEntitlementChecker implements EntitlementChecker {
}
@Override
public void check$java_lang_System$exit(Class<?> callerClass, int status) {
public void check$$exit(Class<?> callerClass, Runtime runtime, int status) {
policyManager.checkExitVM(callerClass);
}
@Override
public void check$$halt(Class<?> callerClass, Runtime runtime, int status) {
policyManager.checkExitVM(callerClass);
}

View file

@ -19,8 +19,8 @@ public class Java23ElasticsearchEntitlementChecker extends ElasticsearchEntitlem
}
@Override
public void check$java_lang_System$exit(Class<?> callerClass, int status) {
public void check$$exit(Class<?> callerClass, Runtime runtime, int status) {
// TODO: this is just an example, we shouldn't really override a method implemented in the superclass
super.check$java_lang_System$exit(callerClass, status);
super.check$$exit(callerClass, runtime, status);
}
}

View file

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.tools;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.stream.Collectors;
public enum ExternalAccess {
PUBLIC_CLASS,
PUBLIC_METHOD,
PROTECTED_METHOD;
private static final String DELIMITER = ":";
public static String toString(EnumSet<ExternalAccess> externalAccesses) {
return externalAccesses.stream().map(Enum::toString).collect(Collectors.joining(DELIMITER));
}
public static EnumSet<ExternalAccess> fromPermissions(
boolean packageExported,
boolean publicClass,
boolean publicMethod,
boolean protectedMethod
) {
if (publicMethod && protectedMethod) {
throw new IllegalArgumentException();
}
EnumSet<ExternalAccess> externalAccesses = EnumSet.noneOf(ExternalAccess.class);
if (publicMethod) {
externalAccesses.add(ExternalAccess.PUBLIC_METHOD);
} else if (protectedMethod) {
externalAccesses.add(ExternalAccess.PROTECTED_METHOD);
}
if (packageExported && publicClass) {
externalAccesses.add(ExternalAccess.PUBLIC_CLASS);
}
return externalAccesses;
}
public static boolean isExternallyAccessible(EnumSet<ExternalAccess> access) {
return access.contains(ExternalAccess.PUBLIC_CLASS)
&& (access.contains(ExternalAccess.PUBLIC_METHOD) || access.contains(ExternalAccess.PROTECTED_METHOD));
}
public static EnumSet<ExternalAccess> fromString(String accessAsString) {
if ("PUBLIC".equals(accessAsString)) {
return EnumSet.of(ExternalAccess.PUBLIC_CLASS, ExternalAccess.PUBLIC_METHOD);
}
if ("PUBLIC-METHOD".equals(accessAsString)) {
return EnumSet.of(ExternalAccess.PUBLIC_METHOD);
}
if ("PRIVATE".equals(accessAsString)) {
return EnumSet.noneOf(ExternalAccess.class);
}
return EnumSet.copyOf(Arrays.stream(accessAsString.split(DELIMITER)).map(ExternalAccess::valueOf).toList());
}
}

View file

@ -11,16 +11,28 @@ package org.elasticsearch.entitlement.tools;
import java.io.IOException;
import java.lang.module.ModuleDescriptor;
import java.net.URI;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class Utils {
public static Map<String, Set<String>> findModuleExports(FileSystem fs) throws IOException {
private static final Set<String> EXCLUDED_MODULES = Set.of(
"java.desktop",
"jdk.jartool",
"jdk.jdi",
"java.security.jgss",
"jdk.jshell"
);
private static Map<String, Set<String>> findModuleExports(FileSystem fs) throws IOException {
var modulesExports = new HashMap<String, Set<String>>();
try (var stream = Files.walk(fs.getPath("modules"))) {
stream.filter(p -> p.getFileName().toString().equals("module-info.class")).forEach(x -> {
@ -42,4 +54,27 @@ public class Utils {
return modulesExports;
}
public interface JdkModuleConsumer {
void accept(String moduleName, List<Path> moduleClasses, Set<String> moduleExports);
}
public static void walkJdkModules(JdkModuleConsumer c) throws IOException {
FileSystem fs = FileSystems.getFileSystem(URI.create("jrt:/"));
var moduleExports = Utils.findModuleExports(fs);
try (var stream = Files.walk(fs.getPath("modules"))) {
var modules = stream.filter(x -> x.toString().endsWith(".class"))
.collect(Collectors.groupingBy(x -> x.subpath(1, 2).toString()));
for (var kv : modules.entrySet()) {
var moduleName = kv.getKey();
if (Utils.EXCLUDED_MODULES.contains(moduleName) == false) {
var thisModuleExports = moduleExports.get(moduleName);
c.accept(moduleName, kv.getValue(), thisModuleExports);
}
}
}
}
}

View file

@ -0,0 +1,50 @@
This tool scans the JDK on which it is running. It takes a list of methods (compatible with the output of the `securitymanager-scanner` tool), and looks for the "public surface" of these methods (i.e. any class/method accessible from regular Java code that calls into the original list, directly or transitively).
It acts basically as a recursive "Find Usages" in Intellij, stopping at the first fully accessible point (public method on a public class).
The tool scans every method in every class inside the same java module; e.g.
if you have a private method `File#normalizedList`, it will scan `java.base` to find
public methods like `File#list(String)`, `File#list(FilenameFilter, String)` and
`File#listFiles(File)`.
The tool considers implemented interfaces (directly); e.g. if we're looking at a
method `C.m`, where `C implements I`, it will look for calls to `I.m`. It will
also consider (indirectly) calls to `S.m` (where `S` is a supertype of `C`), as
it treats calls to `super` in `S.m` as regular calls (e.g. `example() -> S.m() -> C.m()`).
In order to run the tool, use:
```shell
./gradlew :libs:entitlement:tools:public-callers-finder:run <input-file> [<bubble-up-from-public>]
```
Where `input-file` is a CSV file (columns separated by `TAB`) that contains the following columns:
Module name
1. unused
2. unused
3. unused
4. Fully qualified class name (ASM style, with `/` separators)
5. Method name
6. Method descriptor (ASM signature)
7. Visibility (PUBLIC/PUBLIC-METHOD/PRIVATE)
And `bubble-up-from-public` is a boolean (`true|false`) indicating if the code should stop at the first public method (`false`: default, recommended) or continue to find usages recursively even after reaching the "public surface".
The output of the tool is another CSV file, with one line for each entry-point, columns separated by `TAB`
1. Module name
2. File name (from source root)
3. Line number
4. Fully qualified class name (ASM style, with `/` separators)
5. Method name
6. Method descriptor (ASM signature)
7. Visibility (PUBLIC/PUBLIC-METHOD/PRIVATE)
8. Original caller Module name
9. Original caller Class name (ASM style, with `/` separators)
10. Original caller Method name
11. Original caller Visibility
Examples:
```
java.base DeleteOnExitHook.java 50 java/io/DeleteOnExitHook$1 run ()V PUBLIC java.base java/io/File delete PUBLIC
java.base ZipFile.java 254 java/util/zip/ZipFile <init> (Ljava/io/File;ILjava/nio/charset/Charset;)V PUBLIC java.base java/io/File delete PUBLIC
java.logging FileHandler.java 279 java/util/logging/FileHandler <init> ()V PUBLIC java.base java/io/File delete PUBLIC
```

View file

@ -0,0 +1,61 @@
plugins {
id 'application'
}
apply plugin: 'elasticsearch.build'
apply plugin: 'elasticsearch.publish'
tasks.named("dependencyLicenses").configure {
mapping from: /asm-.*/, to: 'asm'
}
group = 'org.elasticsearch.entitlement.tools'
ext {
javaMainClass = "org.elasticsearch.entitlement.tools.publiccallersfinder.Main"
}
application {
mainClass.set(javaMainClass)
applicationDefaultJvmArgs = [
'--add-exports', 'java.base/sun.security.util=ALL-UNNAMED',
'--add-opens', 'java.base/java.lang=ALL-UNNAMED',
'--add-opens', 'java.base/java.net=ALL-UNNAMED',
'--add-opens', 'java.base/java.net.spi=ALL-UNNAMED',
'--add-opens', 'java.base/java.util.concurrent=ALL-UNNAMED',
'--add-opens', 'java.base/javax.crypto=ALL-UNNAMED',
'--add-opens', 'java.base/javax.security.auth=ALL-UNNAMED',
'--add-opens', 'java.base/jdk.internal.logger=ALL-UNNAMED',
'--add-opens', 'java.base/sun.nio.ch=ALL-UNNAMED',
'--add-opens', 'jdk.management.jfr/jdk.management.jfr=ALL-UNNAMED',
'--add-opens', 'java.logging/java.util.logging=ALL-UNNAMED',
'--add-opens', 'java.logging/sun.util.logging.internal=ALL-UNNAMED',
'--add-opens', 'java.naming/javax.naming.ldap.spi=ALL-UNNAMED',
'--add-opens', 'java.rmi/sun.rmi.runtime=ALL-UNNAMED',
'--add-opens', 'jdk.dynalink/jdk.dynalink=ALL-UNNAMED',
'--add-opens', 'jdk.dynalink/jdk.dynalink.linker=ALL-UNNAMED',
'--add-opens', 'java.desktop/sun.awt=ALL-UNNAMED',
'--add-opens', 'java.sql.rowset/javax.sql.rowset.spi=ALL-UNNAMED',
'--add-opens', 'java.sql/java.sql=ALL-UNNAMED',
'--add-opens', 'java.xml.crypto/com.sun.org.apache.xml.internal.security.utils=ALL-UNNAMED'
]
}
repositories {
mavenCentral()
}
dependencies {
compileOnly(project(':libs:core'))
implementation 'org.ow2.asm:asm:9.7.1'
implementation 'org.ow2.asm:asm-util:9.7.1'
implementation(project(':libs:entitlement:tools:common'))
}
tasks.named('forbiddenApisMain').configure {
replaceSignatureFiles 'jdk-signatures'
}
tasks.named("thirdPartyAudit").configure {
ignoreMissingClasses()
}

View file

@ -0,0 +1,26 @@
Copyright (c) 2012 France Télécom
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holders nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.tools.publiccallersfinder;
import org.elasticsearch.entitlement.tools.ExternalAccess;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import java.lang.constant.ClassDesc;
import java.lang.reflect.AccessFlag;
import java.util.EnumSet;
import java.util.Set;
import static org.objectweb.asm.Opcodes.ACC_PROTECTED;
import static org.objectweb.asm.Opcodes.ACC_PUBLIC;
import static org.objectweb.asm.Opcodes.ASM9;
class FindUsagesClassVisitor extends ClassVisitor {
private int classAccess;
private boolean accessibleViaInterfaces;
record MethodDescriptor(String className, String methodName, String methodDescriptor) {}
record EntryPoint(
String moduleName,
String source,
int line,
String className,
String methodName,
String methodDescriptor,
EnumSet<ExternalAccess> access
) {}
interface CallerConsumer {
void accept(String source, int line, String className, String methodName, String methodDescriptor, EnumSet<ExternalAccess> access);
}
private final Set<String> moduleExports;
private final MethodDescriptor methodToFind;
private final CallerConsumer callers;
private String className;
private String source;
protected FindUsagesClassVisitor(Set<String> moduleExports, MethodDescriptor methodToFind, CallerConsumer callers) {
super(ASM9);
this.moduleExports = moduleExports;
this.methodToFind = methodToFind;
this.callers = callers;
}
@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
this.className = name;
this.classAccess = access;
if (interfaces.length > 0) {
this.accessibleViaInterfaces = findAccessibility(interfaces, moduleExports);
}
}
private static boolean findAccessibility(String[] interfaces, Set<String> moduleExports) {
var accessibleViaInterfaces = false;
for (var interfaceName : interfaces) {
if (moduleExports.contains(getPackageName(interfaceName))) {
var interfaceType = Type.getObjectType(interfaceName);
try {
var clazz = Class.forName(interfaceType.getClassName());
if (clazz.accessFlags().contains(AccessFlag.PUBLIC)) {
accessibleViaInterfaces = true;
}
} catch (ClassNotFoundException ignored) {}
}
}
return accessibleViaInterfaces;
}
@Override
public void visitSource(String source, String debug) {
super.visitSource(source, debug);
this.source = source;
}
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
return new FindUsagesMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions), name, descriptor, access);
}
private static String getPackageName(String className) {
return ClassDesc.ofInternalName(className).packageName();
}
private class FindUsagesMethodVisitor extends MethodVisitor {
private final String methodName;
private int line;
private final String methodDescriptor;
private final int methodAccess;
protected FindUsagesMethodVisitor(MethodVisitor mv, String methodName, String methodDescriptor, int methodAccess) {
super(ASM9, mv);
this.methodName = methodName;
this.methodDescriptor = methodDescriptor;
this.methodAccess = methodAccess;
}
@Override
public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
if (methodToFind.className.equals(owner)) {
if (methodToFind.methodName.equals(name)) {
if (methodToFind.methodDescriptor == null || methodToFind.methodDescriptor.equals(descriptor)) {
EnumSet<ExternalAccess> externalAccess = ExternalAccess.fromPermissions(
moduleExports.contains(getPackageName(className)),
accessibleViaInterfaces || (classAccess & ACC_PUBLIC) != 0,
(methodAccess & ACC_PUBLIC) != 0,
(methodAccess & ACC_PROTECTED) != 0
);
callers.accept(source, line, className, methodName, methodDescriptor, externalAccess);
}
}
}
}
@Override
public void visitLineNumber(int line, Label start) {
super.visitLineNumber(line, start);
this.line = line;
}
}
}

View file

@ -0,0 +1,197 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.tools.publiccallersfinder;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.entitlement.tools.ExternalAccess;
import org.elasticsearch.entitlement.tools.Utils;
import org.objectweb.asm.ClassReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class Main {
private static final String SEPARATOR = "\t";
record CallChain(FindUsagesClassVisitor.EntryPoint entryPoint, CallChain next) {}
interface UsageConsumer {
void usageFound(CallChain originalEntryPoint, CallChain newMethod);
}
private static void findTransitiveUsages(
Collection<CallChain> firstLevelCallers,
List<Path> classesToScan,
Set<String> moduleExports,
boolean bubbleUpFromPublic,
UsageConsumer usageConsumer
) {
for (var caller : firstLevelCallers) {
var methodsToCheck = new ArrayDeque<>(Set.of(caller));
var methodsSeen = new HashSet<FindUsagesClassVisitor.EntryPoint>();
while (methodsToCheck.isEmpty() == false) {
var methodToCheck = methodsToCheck.removeFirst();
var m = methodToCheck.entryPoint();
var visitor2 = new FindUsagesClassVisitor(
moduleExports,
new FindUsagesClassVisitor.MethodDescriptor(m.className(), m.methodName(), m.methodDescriptor()),
(source, line, className, methodName, methodDescriptor, access) -> {
var newMethod = new CallChain(
new FindUsagesClassVisitor.EntryPoint(
m.moduleName(),
source,
line,
className,
methodName,
methodDescriptor,
access
),
methodToCheck
);
var notSeenBefore = methodsSeen.add(newMethod.entryPoint());
if (notSeenBefore) {
if (ExternalAccess.isExternallyAccessible(access)) {
usageConsumer.usageFound(caller.next(), newMethod);
}
if (access.contains(ExternalAccess.PUBLIC_METHOD) == false || bubbleUpFromPublic) {
methodsToCheck.add(newMethod);
}
}
}
);
for (var classFile : classesToScan) {
try {
ClassReader cr = new ClassReader(Files.newInputStream(classFile));
cr.accept(visitor2, 0);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}
}
private static void identifyTopLevelEntryPoints(
FindUsagesClassVisitor.MethodDescriptor methodToFind,
String methodToFindModule,
EnumSet<ExternalAccess> methodToFindAccess,
boolean bubbleUpFromPublic
) throws IOException {
Utils.walkJdkModules((moduleName, moduleClasses, moduleExports) -> {
var originalCallers = new ArrayList<CallChain>();
var visitor = new FindUsagesClassVisitor(
moduleExports,
methodToFind,
(source, line, className, methodName, methodDescriptor, access) -> originalCallers.add(
new CallChain(
new FindUsagesClassVisitor.EntryPoint(moduleName, source, line, className, methodName, methodDescriptor, access),
new CallChain(
new FindUsagesClassVisitor.EntryPoint(
methodToFindModule,
"",
0,
methodToFind.className(),
methodToFind.methodName(),
methodToFind.methodDescriptor(),
methodToFindAccess
),
null
)
)
)
);
for (var classFile : moduleClasses) {
try {
ClassReader cr = new ClassReader(Files.newInputStream(classFile));
cr.accept(visitor, 0);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
originalCallers.stream().filter(c -> ExternalAccess.isExternallyAccessible(c.entryPoint().access())).forEach(c -> {
var originalCaller = c.next();
printRow(getEntryPointString(c.entryPoint().moduleName(), c.entryPoint()), getOriginalEntryPointString(originalCaller));
});
var firstLevelCallers = bubbleUpFromPublic ? originalCallers : originalCallers.stream().filter(Main::isNotFullyPublic).toList();
if (firstLevelCallers.isEmpty() == false) {
findTransitiveUsages(
firstLevelCallers,
moduleClasses,
moduleExports,
bubbleUpFromPublic,
(originalEntryPoint, newMethod) -> printRow(
getEntryPointString(moduleName, newMethod.entryPoint()),
getOriginalEntryPointString(originalEntryPoint)
)
);
}
});
}
private static boolean isNotFullyPublic(CallChain c) {
return (c.entryPoint().access().contains(ExternalAccess.PUBLIC_CLASS)
&& c.entryPoint().access().contains(ExternalAccess.PUBLIC_METHOD)) == false;
}
@SuppressForbidden(reason = "This tool prints the CSV to stdout")
private static void printRow(String entryPointString, String originalEntryPoint) {
System.out.println(entryPointString + SEPARATOR + originalEntryPoint);
}
private static String getEntryPointString(String moduleName, FindUsagesClassVisitor.EntryPoint e) {
return moduleName + SEPARATOR + e.source() + SEPARATOR + e.line() + SEPARATOR + e.className() + SEPARATOR + e.methodName()
+ SEPARATOR + e.methodDescriptor() + SEPARATOR + ExternalAccess.toString(e.access());
}
private static String getOriginalEntryPointString(CallChain originalCallChain) {
return originalCallChain.entryPoint().moduleName() + SEPARATOR + originalCallChain.entryPoint().className() + SEPARATOR
+ originalCallChain.entryPoint().methodName() + SEPARATOR + ExternalAccess.toString(originalCallChain.entryPoint().access());
}
interface MethodDescriptorConsumer {
void accept(FindUsagesClassVisitor.MethodDescriptor methodDescriptor, String moduleName, EnumSet<ExternalAccess> access)
throws IOException;
}
private static void parseCsv(Path csvPath, MethodDescriptorConsumer methodConsumer) throws IOException {
var lines = Files.readAllLines(csvPath);
for (var l : lines) {
var tokens = l.split(SEPARATOR);
var moduleName = tokens[0];
var className = tokens[3];
var methodName = tokens[4];
var methodDescriptor = tokens[5];
var access = ExternalAccess.fromString(tokens[6]);
methodConsumer.accept(new FindUsagesClassVisitor.MethodDescriptor(className, methodName, methodDescriptor), moduleName, access);
}
}
public static void main(String[] args) throws IOException {
var csvFilePath = Path.of(args[0]);
boolean bubbleUpFromPublic = args.length >= 2 && Boolean.parseBoolean(args[1]);
parseCsv(csvFilePath, (method, module, access) -> identifyTopLevelEntryPoints(method, module, access, bubbleUpFromPublic));
}
}

View file

@ -10,47 +10,35 @@
package org.elasticsearch.entitlement.tools.securitymanager.scanner;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.entitlement.tools.ExternalAccess;
import org.elasticsearch.entitlement.tools.Utils;
import org.objectweb.asm.ClassReader;
import java.io.IOException;
import java.net.URI;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
public class Main {
static final Set<String> excludedModules = Set.of("java.desktop");
private static void identifySMChecksEntryPoints() throws IOException {
FileSystem fs = FileSystems.getFileSystem(URI.create("jrt:/"));
var moduleExports = Utils.findModuleExports(fs);
var callers = new HashMap<String, List<SecurityCheckClassVisitor.CallerInfo>>();
var visitor = new SecurityCheckClassVisitor(callers);
try (var stream = Files.walk(fs.getPath("modules"))) {
stream.filter(x -> x.toString().endsWith(".class")).forEach(x -> {
var moduleName = x.subpath(1, 2).toString();
if (excludedModules.contains(moduleName) == false) {
try {
ClassReader cr = new ClassReader(Files.newInputStream(x));
visitor.setCurrentModule(moduleName, moduleExports.get(moduleName));
var path = x.getNameCount() > 3 ? x.subpath(2, x.getNameCount() - 1).toString() : "";
visitor.setCurrentSourcePath(path);
cr.accept(visitor, 0);
} catch (IOException e) {
throw new RuntimeException(e);
}
Utils.walkJdkModules((moduleName, moduleClasses, moduleExports) -> {
for (var classFile : moduleClasses) {
try {
ClassReader cr = new ClassReader(Files.newInputStream(classFile));
visitor.setCurrentModule(moduleName, moduleExports);
var path = classFile.getNameCount() > 3 ? classFile.subpath(2, classFile.getNameCount() - 1).toString() : "";
visitor.setCurrentSourcePath(path);
cr.accept(visitor, 0);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}
}
});
printToStdout(callers);
}
@ -68,16 +56,8 @@ public class Main {
private static String toString(String calleeName, SecurityCheckClassVisitor.CallerInfo callerInfo) {
var s = callerInfo.moduleName() + SEPARATOR + callerInfo.source() + SEPARATOR + callerInfo.line() + SEPARATOR + callerInfo
.className() + SEPARATOR + callerInfo.methodName() + SEPARATOR + callerInfo.methodDescriptor() + SEPARATOR;
if (callerInfo.externalAccess().contains(SecurityCheckClassVisitor.ExternalAccess.METHOD)
&& callerInfo.externalAccess().contains(SecurityCheckClassVisitor.ExternalAccess.CLASS)) {
s += "PUBLIC";
} else if (callerInfo.externalAccess().contains(SecurityCheckClassVisitor.ExternalAccess.METHOD)) {
s += "PUBLIC-METHOD";
} else {
s += "PRIVATE";
}
.className() + SEPARATOR + callerInfo.methodName() + SEPARATOR + callerInfo.methodDescriptor() + SEPARATOR + ExternalAccess
.toString(callerInfo.externalAccess());
if (callerInfo.runtimePermissionType() != null) {
s += SEPARATOR + callerInfo.runtimePermissionType();

View file

@ -10,6 +10,7 @@
package org.elasticsearch.entitlement.tools.securitymanager.scanner;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.entitlement.tools.ExternalAccess;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
@ -27,6 +28,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.objectweb.asm.Opcodes.ACC_PROTECTED;
import static org.objectweb.asm.Opcodes.ACC_PUBLIC;
import static org.objectweb.asm.Opcodes.ASM9;
import static org.objectweb.asm.Opcodes.GETSTATIC;
@ -42,11 +44,6 @@ class SecurityCheckClassVisitor extends ClassVisitor {
static final String SECURITY_MANAGER_INTERNAL_NAME = "java/lang/SecurityManager";
static final Set<String> excludedClasses = Set.of(SECURITY_MANAGER_INTERNAL_NAME);
enum ExternalAccess {
CLASS,
METHOD
}
record CallerInfo(
String moduleName,
String source,
@ -208,15 +205,12 @@ class SecurityCheckClassVisitor extends ClassVisitor {
|| opcode == INVOKEDYNAMIC) {
if (SECURITY_MANAGER_INTERNAL_NAME.equals(owner)) {
EnumSet<ExternalAccess> externalAccesses = EnumSet.noneOf(ExternalAccess.class);
if (moduleExports.contains(getPackageName(className))) {
if ((methodAccess & ACC_PUBLIC) != 0) {
externalAccesses.add(ExternalAccess.METHOD);
}
if ((classAccess & ACC_PUBLIC) != 0) {
externalAccesses.add(ExternalAccess.CLASS);
}
}
EnumSet<ExternalAccess> externalAccesses = ExternalAccess.fromPermissions(
moduleExports.contains(getPackageName(className)),
(classAccess & ACC_PUBLIC) != 0,
(methodAccess & ACC_PUBLIC) != 0,
(methodAccess & ACC_PROTECTED) != 0
);
if (name.equals("checkPermission")) {
var callers = callerInfoByMethod.computeIfAbsent(name, ignored -> new ArrayList<>());

View file

@ -83,10 +83,15 @@ public class SpatialEnvelopeVisitor implements GeometryVisitor<Boolean, RuntimeE
return Optional.empty();
}
public enum WrapLongitude {
NO_WRAP,
WRAP
}
/**
* Determine the BBOX assuming the CRS is geographic (eg WGS84) and optionally wrapping the longitude around the dateline.
*/
public static Optional<Rectangle> visitGeo(Geometry geometry, boolean wrapLongitude) {
public static Optional<Rectangle> visitGeo(Geometry geometry, WrapLongitude wrapLongitude) {
var visitor = new SpatialEnvelopeVisitor(new GeoPointVisitor(wrapLongitude));
if (geometry.visit(visitor)) {
return Optional.of(visitor.getResult());
@ -181,40 +186,16 @@ public class SpatialEnvelopeVisitor implements GeometryVisitor<Boolean, RuntimeE
* </ul>
*/
public static class GeoPointVisitor implements PointVisitor {
private double minY = Double.POSITIVE_INFINITY;
private double maxY = Double.NEGATIVE_INFINITY;
private double minNegX = Double.POSITIVE_INFINITY;
private double maxNegX = Double.NEGATIVE_INFINITY;
private double minPosX = Double.POSITIVE_INFINITY;
private double maxPosX = Double.NEGATIVE_INFINITY;
protected double minY = Double.POSITIVE_INFINITY;
protected double maxY = Double.NEGATIVE_INFINITY;
protected double minNegX = Double.POSITIVE_INFINITY;
protected double maxNegX = Double.NEGATIVE_INFINITY;
protected double minPosX = Double.POSITIVE_INFINITY;
protected double maxPosX = Double.NEGATIVE_INFINITY;
public double getMinY() {
return minY;
}
private final WrapLongitude wrapLongitude;
public double getMaxY() {
return maxY;
}
public double getMinNegX() {
return minNegX;
}
public double getMaxNegX() {
return maxNegX;
}
public double getMinPosX() {
return minPosX;
}
public double getMaxPosX() {
return maxPosX;
}
private final boolean wrapLongitude;
public GeoPointVisitor(boolean wrapLongitude) {
public GeoPointVisitor(WrapLongitude wrapLongitude) {
this.wrapLongitude = wrapLongitude;
}
@ -253,32 +234,35 @@ public class SpatialEnvelopeVisitor implements GeometryVisitor<Boolean, RuntimeE
return getResult(minNegX, minPosX, maxNegX, maxPosX, maxY, minY, wrapLongitude);
}
private static Rectangle getResult(
protected static Rectangle getResult(
double minNegX,
double minPosX,
double maxNegX,
double maxPosX,
double maxY,
double minY,
boolean wrapLongitude
WrapLongitude wrapLongitude
) {
assert Double.isFinite(maxY);
if (Double.isInfinite(minPosX)) {
return new Rectangle(minNegX, maxNegX, maxY, minY);
} else if (Double.isInfinite(minNegX)) {
return new Rectangle(minPosX, maxPosX, maxY, minY);
} else if (wrapLongitude) {
double unwrappedWidth = maxPosX - minNegX;
double wrappedWidth = (180 - minPosX) - (-180 - maxNegX);
if (unwrappedWidth <= wrappedWidth) {
return new Rectangle(minNegX, maxPosX, maxY, minY);
} else {
return new Rectangle(minPosX, maxNegX, maxY, minY);
}
} else {
return new Rectangle(minNegX, maxPosX, maxY, minY);
return switch (wrapLongitude) {
case NO_WRAP -> new Rectangle(minNegX, maxPosX, maxY, minY);
case WRAP -> maybeWrap(minNegX, minPosX, maxNegX, maxPosX, maxY, minY);
};
}
}
private static Rectangle maybeWrap(double minNegX, double minPosX, double maxNegX, double maxPosX, double maxY, double minY) {
double unwrappedWidth = maxPosX - minNegX;
double wrappedWidth = 360 + maxNegX - minPosX;
return unwrappedWidth <= wrappedWidth
? new Rectangle(minNegX, maxPosX, maxY, minY)
: new Rectangle(minPosX, maxNegX, maxY, minY);
}
}
private boolean isValid() {

View file

@ -13,6 +13,7 @@ import org.elasticsearch.geo.GeometryTestUtils;
import org.elasticsearch.geo.ShapeTestUtils;
import org.elasticsearch.geometry.Point;
import org.elasticsearch.geometry.Rectangle;
import org.elasticsearch.geometry.utils.SpatialEnvelopeVisitor.WrapLongitude;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.equalTo;
@ -36,7 +37,7 @@ public class SpatialEnvelopeVisitorTests extends ESTestCase {
public void testVisitGeoShapeNoWrap() {
for (int i = 0; i < 1000; i++) {
var geometry = GeometryTestUtils.randomGeometryWithoutCircle(0, false);
var bbox = SpatialEnvelopeVisitor.visitGeo(geometry, false);
var bbox = SpatialEnvelopeVisitor.visitGeo(geometry, WrapLongitude.NO_WRAP);
assertNotNull(bbox);
assertTrue(i + ": " + geometry, bbox.isPresent());
var result = bbox.get();
@ -48,7 +49,8 @@ public class SpatialEnvelopeVisitorTests extends ESTestCase {
public void testVisitGeoShapeWrap() {
for (int i = 0; i < 1000; i++) {
var geometry = GeometryTestUtils.randomGeometryWithoutCircle(0, true);
var bbox = SpatialEnvelopeVisitor.visitGeo(geometry, false);
// TODO this should be WRAP instead
var bbox = SpatialEnvelopeVisitor.visitGeo(geometry, WrapLongitude.NO_WRAP);
assertNotNull(bbox);
assertTrue(i + ": " + geometry, bbox.isPresent());
var result = bbox.get();
@ -81,7 +83,7 @@ public class SpatialEnvelopeVisitorTests extends ESTestCase {
}
public void testVisitGeoPointsNoWrapping() {
var visitor = new SpatialEnvelopeVisitor(new SpatialEnvelopeVisitor.GeoPointVisitor(false));
var visitor = new SpatialEnvelopeVisitor(new SpatialEnvelopeVisitor.GeoPointVisitor(WrapLongitude.NO_WRAP));
double minY = Double.MAX_VALUE;
double maxY = -Double.MAX_VALUE;
double minX = Double.MAX_VALUE;
@ -103,7 +105,7 @@ public class SpatialEnvelopeVisitorTests extends ESTestCase {
}
public void testVisitGeoPointsWrapping() {
var visitor = new SpatialEnvelopeVisitor(new SpatialEnvelopeVisitor.GeoPointVisitor(true));
var visitor = new SpatialEnvelopeVisitor(new SpatialEnvelopeVisitor.GeoPointVisitor(WrapLongitude.WRAP));
double minY = Double.POSITIVE_INFINITY;
double maxY = Double.NEGATIVE_INFINITY;
double minNegX = Double.POSITIVE_INFINITY;
@ -145,7 +147,7 @@ public class SpatialEnvelopeVisitorTests extends ESTestCase {
}
public void testWillCrossDateline() {
var visitor = new SpatialEnvelopeVisitor(new SpatialEnvelopeVisitor.GeoPointVisitor(true));
var visitor = new SpatialEnvelopeVisitor(new SpatialEnvelopeVisitor.GeoPointVisitor(WrapLongitude.WRAP));
visitor.visit(new Point(-90.0, 0.0));
visitor.visit(new Point(90.0, 0.0));
assertCrossesDateline(visitor, false);