[Entitlements] Network access checks on Sockets (#120093)

This commit is contained in:
Lorenzo Dematté 2025-01-15 22:01:56 +01:00 committed by GitHub
parent 1448f12d23
commit 1848d6bb93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 356 additions and 31 deletions

View file

@ -20,8 +20,11 @@ import java.net.FileNameMap;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.MulticastSocket; import java.net.MulticastSocket;
import java.net.NetworkInterface; import java.net.NetworkInterface;
import java.net.Proxy;
import java.net.ProxySelector; import java.net.ProxySelector;
import java.net.ResponseCache; import java.net.ResponseCache;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.net.SocketImplFactory; import java.net.SocketImplFactory;
import java.net.URL; import java.net.URL;
@ -215,4 +218,40 @@ public interface EntitlementChecker {
void check$java_net_MulticastSocket$leaveGroup(Class<?> callerClass, MulticastSocket that, SocketAddress addr, NetworkInterface ni); void check$java_net_MulticastSocket$leaveGroup(Class<?> callerClass, MulticastSocket that, SocketAddress addr, NetworkInterface ni);
void check$java_net_MulticastSocket$send(Class<?> callerClass, MulticastSocket that, DatagramPacket p, byte ttl); void check$java_net_MulticastSocket$send(Class<?> callerClass, MulticastSocket that, DatagramPacket p, byte ttl);
// Binding/connecting ctor
void check$java_net_ServerSocket$(Class<?> callerClass, int port);
void check$java_net_ServerSocket$(Class<?> callerClass, int port, int backlog);
void check$java_net_ServerSocket$(Class<?> callerClass, int port, int backlog, InetAddress bindAddr);
void check$java_net_ServerSocket$accept(Class<?> callerClass, ServerSocket that);
void check$java_net_ServerSocket$implAccept(Class<?> callerClass, ServerSocket that, Socket s);
void check$java_net_ServerSocket$bind(Class<?> callerClass, ServerSocket that, SocketAddress endpoint);
void check$java_net_ServerSocket$bind(Class<?> callerClass, ServerSocket that, SocketAddress endpoint, int backlog);
// Binding/connecting ctors
void check$java_net_Socket$(Class<?> callerClass, Proxy proxy);
void check$java_net_Socket$(Class<?> callerClass, String host, int port);
void check$java_net_Socket$(Class<?> callerClass, InetAddress address, int port);
void check$java_net_Socket$(Class<?> callerClass, String host, int port, InetAddress localAddr, int localPort);
void check$java_net_Socket$(Class<?> callerClass, InetAddress address, int port, InetAddress localAddr, int localPort);
void check$java_net_Socket$(Class<?> callerClass, String host, int port, boolean stream);
void check$java_net_Socket$(Class<?> callerClass, InetAddress host, int port, boolean stream);
void check$java_net_Socket$bind(Class<?> callerClass, Socket that, SocketAddress endpoint);
void check$java_net_Socket$connect(Class<?> callerClass, Socket that, SocketAddress endpoint);
void check$java_net_Socket$connect(Class<?> callerClass, Socket that, SocketAddress endpoint, int backlog);
} }

View file

@ -10,14 +10,18 @@
package org.elasticsearch.entitlement.qa.common; package org.elasticsearch.entitlement.qa.common;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.DatagramPacket; import java.net.DatagramPacket;
import java.net.DatagramSocket; import java.net.DatagramSocket;
import java.net.DatagramSocketImpl; import java.net.DatagramSocketImpl;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.NetworkInterface; import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.net.SocketException; import java.net.SocketException;
import java.net.SocketImpl;
import java.security.cert.Certificate; import java.security.cert.Certificate;
import java.text.BreakIterator; import java.text.BreakIterator;
import java.text.Collator; import java.text.Collator;
@ -297,6 +301,81 @@ class DummyImplementations {
} }
} }
private static class DummySocketImpl extends SocketImpl {
@Override
protected void create(boolean stream) {}
@Override
protected void connect(String host, int port) {}
@Override
protected void connect(InetAddress address, int port) {}
@Override
protected void connect(SocketAddress address, int timeout) {}
@Override
protected void bind(InetAddress host, int port) {}
@Override
protected void listen(int backlog) {}
@Override
protected void accept(SocketImpl s) {}
@Override
protected InputStream getInputStream() {
return null;
}
@Override
protected OutputStream getOutputStream() {
return null;
}
@Override
protected int available() {
return 0;
}
@Override
protected void close() {}
@Override
protected void sendUrgentData(int data) {}
@Override
public void setOption(int optID, Object value) {}
@Override
public Object getOption(int optID) {
return null;
}
}
static class DummySocket extends Socket {
DummySocket() throws SocketException {
super(new DummySocketImpl());
}
}
static class DummyServerSocket extends ServerSocket {
DummyServerSocket() {
super(new DummySocketImpl());
}
}
static class DummyBoundServerSocket extends ServerSocket {
DummyBoundServerSocket() {
super(new DummySocketImpl());
}
@Override
public boolean isBound() {
return true;
}
}
static class DummySSLSocketFactory extends SSLSocketFactory { static class DummySSLSocketFactory extends SSLSocketFactory {
@Override @Override
public Socket createSocket(String host, int port) { public Socket createSocket(String host, int port) {

View file

@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.entitlement.qa.common;
import org.elasticsearch.core.SuppressForbidden;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
class NetworkAccessCheckActions {
static void serverSocketAccept() throws IOException {
try (ServerSocket socket = new DummyImplementations.DummyBoundServerSocket()) {
try {
socket.accept();
} catch (IOException e) {
// Our dummy socket cannot accept connections unless we tell the JDK how to create a socket for it.
// But Socket.setSocketImplFactory(); is one of the methods we always forbid, so we cannot use it.
// Still, we can check accept is called (allowed/denied), we don't care if it fails later for this
// known reason.
assert e.getMessage().contains("client socket implementation factory not set");
}
}
}
static void serverSocketBind() throws IOException {
try (ServerSocket socket = new DummyImplementations.DummyServerSocket()) {
socket.bind(null);
}
}
@SuppressForbidden(reason = "Testing entitlement check on forbidden action")
static void createSocketWithProxy() throws IOException {
try (Socket socket = new Socket(new Proxy(Proxy.Type.HTTP, new InetSocketAddress(0)))) {
assert socket.isBound() == false;
}
}
static void socketBind() throws IOException {
try (Socket socket = new DummyImplementations.DummySocket()) {
socket.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
}
}
@SuppressForbidden(reason = "Testing entitlement check on forbidden action")
static void socketConnect() throws IOException {
try (Socket socket = new DummyImplementations.DummySocket()) {
socket.connect(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
}
}
}

View file

@ -149,7 +149,13 @@ public class RestEntitlementsCheckAction extends BaseRestHandler {
entry("datagram_socket_send", forPlugins(RestEntitlementsCheckAction::sendDatagramSocket)), entry("datagram_socket_send", forPlugins(RestEntitlementsCheckAction::sendDatagramSocket)),
entry("datagram_socket_receive", forPlugins(RestEntitlementsCheckAction::receiveDatagramSocket)), entry("datagram_socket_receive", forPlugins(RestEntitlementsCheckAction::receiveDatagramSocket)),
entry("datagram_socket_join_group", forPlugins(RestEntitlementsCheckAction::joinGroupDatagramSocket)), entry("datagram_socket_join_group", forPlugins(RestEntitlementsCheckAction::joinGroupDatagramSocket)),
entry("datagram_socket_leave_group", forPlugins(RestEntitlementsCheckAction::leaveGroupDatagramSocket)) entry("datagram_socket_leave_group", forPlugins(RestEntitlementsCheckAction::leaveGroupDatagramSocket)),
entry("create_socket_with_proxy", forPlugins(NetworkAccessCheckActions::createSocketWithProxy)),
entry("socket_bind", forPlugins(NetworkAccessCheckActions::socketBind)),
entry("socket_connect", forPlugins(NetworkAccessCheckActions::socketConnect)),
entry("server_socket_bind", forPlugins(NetworkAccessCheckActions::serverSocketBind)),
entry("server_socket_accept", forPlugins(NetworkAccessCheckActions::serverSocketAccept))
); );
private static void createURLStreamHandlerProvider() { private static void createURLStreamHandlerProvider() {

View file

@ -22,6 +22,7 @@ import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker
import org.elasticsearch.entitlement.runtime.policy.CreateClassLoaderEntitlement; import org.elasticsearch.entitlement.runtime.policy.CreateClassLoaderEntitlement;
import org.elasticsearch.entitlement.runtime.policy.Entitlement; import org.elasticsearch.entitlement.runtime.policy.Entitlement;
import org.elasticsearch.entitlement.runtime.policy.ExitVMEntitlement; import org.elasticsearch.entitlement.runtime.policy.ExitVMEntitlement;
import org.elasticsearch.entitlement.runtime.policy.NetworkEntitlement;
import org.elasticsearch.entitlement.runtime.policy.Policy; import org.elasticsearch.entitlement.runtime.policy.Policy;
import org.elasticsearch.entitlement.runtime.policy.PolicyManager; import org.elasticsearch.entitlement.runtime.policy.PolicyManager;
import org.elasticsearch.entitlement.runtime.policy.PolicyParser; import org.elasticsearch.entitlement.runtime.policy.PolicyParser;
@ -44,6 +45,9 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.entitlement.runtime.policy.NetworkEntitlement.ACCEPT_ACTION;
import static org.elasticsearch.entitlement.runtime.policy.NetworkEntitlement.CONNECT_ACTION;
import static org.elasticsearch.entitlement.runtime.policy.NetworkEntitlement.LISTEN_ACTION;
import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.ALL_UNNAMED; import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.ALL_UNNAMED;
/** /**
@ -97,7 +101,15 @@ public class EntitlementInitialization {
List.of( List.of(
new Scope("org.elasticsearch.base", List.of(new CreateClassLoaderEntitlement())), new Scope("org.elasticsearch.base", List.of(new CreateClassLoaderEntitlement())),
new Scope("org.elasticsearch.xcontent", List.of(new CreateClassLoaderEntitlement())), new Scope("org.elasticsearch.xcontent", List.of(new CreateClassLoaderEntitlement())),
new Scope("org.elasticsearch.server", List.of(new ExitVMEntitlement(), new CreateClassLoaderEntitlement())) new Scope(
"org.elasticsearch.server",
List.of(
new ExitVMEntitlement(),
new CreateClassLoaderEntitlement(),
new NetworkEntitlement(LISTEN_ACTION | CONNECT_ACTION | ACCEPT_ACTION)
)
),
new Scope("org.apache.httpcomponents.httpclient", List.of(new NetworkEntitlement(CONNECT_ACTION)))
) )
); );
// agents run without a module, so this is a special hack for the apm agent // agents run without a module, so this is a special hack for the apm agent

View file

@ -24,8 +24,11 @@ import java.net.FileNameMap;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.MulticastSocket; import java.net.MulticastSocket;
import java.net.NetworkInterface; import java.net.NetworkInterface;
import java.net.Proxy;
import java.net.ProxySelector; import java.net.ProxySelector;
import java.net.ResponseCache; import java.net.ResponseCache;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.net.SocketImplFactory; import java.net.SocketImplFactory;
import java.net.URL; import java.net.URL;
@ -414,4 +417,91 @@ public class ElasticsearchEntitlementChecker implements EntitlementChecker {
public void check$java_net_MulticastSocket$send(Class<?> callerClass, MulticastSocket that, DatagramPacket p, byte ttl) { public void check$java_net_MulticastSocket$send(Class<?> callerClass, MulticastSocket that, DatagramPacket p, byte ttl) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.CONNECT_ACTION | NetworkEntitlement.ACCEPT_ACTION); policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.CONNECT_ACTION | NetworkEntitlement.ACCEPT_ACTION);
} }
@Override
public void check$java_net_ServerSocket$(Class<?> callerClass, int port) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION);
}
@Override
public void check$java_net_ServerSocket$(Class<?> callerClass, int port, int backlog) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION);
}
@Override
public void check$java_net_ServerSocket$(Class<?> callerClass, int port, int backlog, InetAddress bindAddr) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION);
}
@Override
public void check$java_net_ServerSocket$accept(Class<?> callerClass, ServerSocket that) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.ACCEPT_ACTION);
}
@Override
public void check$java_net_ServerSocket$implAccept(Class<?> callerClass, ServerSocket that, Socket s) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.ACCEPT_ACTION);
}
@Override
public void check$java_net_ServerSocket$bind(Class<?> callerClass, ServerSocket that, SocketAddress endpoint) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION);
}
@Override
public void check$java_net_ServerSocket$bind(Class<?> callerClass, ServerSocket that, SocketAddress endpoint, int backlog) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION);
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, Proxy proxy) {
if (proxy.type() == Proxy.Type.SOCKS || proxy.type() == Proxy.Type.HTTP) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.CONNECT_ACTION);
}
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, String host, int port) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION | NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, InetAddress address, int port) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION | NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, String host, int port, InetAddress localAddr, int localPort) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION | NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, InetAddress address, int port, InetAddress localAddr, int localPort) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION | NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, String host, int port, boolean stream) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION | NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$(Class<?> callerClass, InetAddress host, int port, boolean stream) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION | NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$bind(Class<?> callerClass, Socket that, SocketAddress endpoint) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.LISTEN_ACTION);
}
@Override
public void check$java_net_Socket$connect(Class<?> callerClass, Socket that, SocketAddress endpoint) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.CONNECT_ACTION);
}
@Override
public void check$java_net_Socket$connect(Class<?> callerClass, Socket that, SocketAddress endpoint, int backlog) {
policyManager.checkNetworkAccess(callerClass, NetworkEntitlement.CONNECT_ACTION);
}
} }

View file

@ -58,7 +58,11 @@ public class NetworkEntitlement implements Entitlement {
this.actions = actionsInt; this.actions = actionsInt;
} }
public static Object printActions(int actions) { public NetworkEntitlement(int actions) {
this.actions = actions;
}
public static String printActions(int actions) {
var joiner = new StringJoiner(","); var joiner = new StringJoiner(",");
for (var entry : ACTION_MAP.entrySet()) { for (var entry : ACTION_MAP.entrySet()) {
var action = entry.getValue(); var action = entry.getValue();

View file

@ -200,21 +200,21 @@ public class PolicyManager {
return; return;
} }
ModuleEntitlements entitlements = getEntitlements(requestingClass); ModuleEntitlements entitlements = getEntitlements(requestingClass, NetworkEntitlement.class);
if (entitlements.getEntitlements(NetworkEntitlement.class).anyMatch(n -> n.matchActions(actions))) { if (entitlements.getEntitlements(NetworkEntitlement.class).anyMatch(n -> n.matchActions(actions))) {
logger.debug( logger.debug(
() -> Strings.format( () -> Strings.format(
"Entitled: class [%s], module [%s], entitlement [Network], actions [Ox%X]", "Entitled: class [%s], module [%s], entitlement [network], actions [%s]",
requestingClass, requestingClass,
requestingClass.getModule().getName(), requestingClass.getModule().getName(),
actions NetworkEntitlement.printActions(actions)
) )
); );
return; return;
} }
throw new NotEntitledException( throw new NotEntitledException(
Strings.format( Strings.format(
"Missing entitlement: class [%s], module [%s], entitlement [Network], actions [%s]", "Missing entitlement: class [%s], module [%s], entitlement [network], actions [%s]",
requestingClass, requestingClass,
requestingClass.getModule().getName(), requestingClass.getModule().getName(),
NetworkEntitlement.printActions(actions) NetworkEntitlement.printActions(actions)
@ -228,14 +228,14 @@ public class PolicyManager {
return; return;
} }
ModuleEntitlements entitlements = getEntitlements(requestingClass); ModuleEntitlements entitlements = getEntitlements(requestingClass, entitlementClass);
if (entitlements.hasEntitlement(entitlementClass)) { if (entitlements.hasEntitlement(entitlementClass)) {
logger.debug( logger.debug(
() -> Strings.format( () -> Strings.format(
"Entitled: class [%s], module [%s], entitlement [%s]", "Entitled: class [%s], module [%s], entitlement [%s]",
requestingClass, requestingClass,
requestingClass.getModule().getName(), requestingClass.getModule().getName(),
entitlementClass.getSimpleName() PolicyParser.getEntitlementTypeName(entitlementClass)
) )
); );
return; return;
@ -245,19 +245,22 @@ public class PolicyManager {
"Missing entitlement: class [%s], module [%s], entitlement [%s]", "Missing entitlement: class [%s], module [%s], entitlement [%s]",
requestingClass, requestingClass,
requestingClass.getModule().getName(), requestingClass.getModule().getName(),
entitlementClass.getSimpleName() PolicyParser.getEntitlementTypeName(entitlementClass)
) )
); );
} }
ModuleEntitlements getEntitlements(Class<?> requestingClass) { ModuleEntitlements getEntitlements(Class<?> requestingClass, Class<? extends Entitlement> entitlementClass) {
return moduleEntitlementsMap.computeIfAbsent(requestingClass.getModule(), m -> computeEntitlements(requestingClass)); return moduleEntitlementsMap.computeIfAbsent(
requestingClass.getModule(),
m -> computeEntitlements(requestingClass, entitlementClass)
);
} }
private ModuleEntitlements computeEntitlements(Class<?> requestingClass) { private ModuleEntitlements computeEntitlements(Class<?> requestingClass, Class<? extends Entitlement> entitlementClass) {
Module requestingModule = requestingClass.getModule(); Module requestingModule = requestingClass.getModule();
if (isServerModule(requestingModule)) { if (isServerModule(requestingModule)) {
return getModuleScopeEntitlements(requestingClass, serverEntitlements, requestingModule.getName()); return getModuleScopeEntitlements(requestingClass, serverEntitlements, requestingModule.getName(), "server", entitlementClass);
} }
// plugins // plugins
@ -271,7 +274,7 @@ public class PolicyManager {
} else { } else {
scopeName = requestingModule.getName(); scopeName = requestingModule.getName();
} }
return getModuleScopeEntitlements(requestingClass, pluginEntitlements, scopeName); return getModuleScopeEntitlements(requestingClass, pluginEntitlements, scopeName, pluginName, entitlementClass);
} }
} }
@ -287,11 +290,19 @@ public class PolicyManager {
private ModuleEntitlements getModuleScopeEntitlements( private ModuleEntitlements getModuleScopeEntitlements(
Class<?> callerClass, Class<?> callerClass,
Map<String, List<Entitlement>> scopeEntitlements, Map<String, List<Entitlement>> scopeEntitlements,
String moduleName String moduleName,
String component,
Class<? extends Entitlement> entitlementClass
) { ) {
var entitlements = scopeEntitlements.get(moduleName); var entitlements = scopeEntitlements.get(moduleName);
if (entitlements == null) { if (entitlements == null) {
logger.warn("No applicable entitlement policy for module [{}], class [{}]", moduleName, callerClass); logger.warn(
"No applicable entitlement policy for entitlement [{}] in [{}], module [{}], class [{}]",
PolicyParser.getEntitlementTypeName(entitlementClass),
component,
moduleName,
callerClass
);
return ModuleEntitlements.NONE; return ModuleEntitlements.NONE;
} }
return ModuleEntitlements.from(entitlements); return ModuleEntitlements.from(entitlements);

View file

@ -38,7 +38,7 @@ import static org.hamcrest.Matchers.sameInstance;
public class PolicyManagerTests extends ESTestCase { public class PolicyManagerTests extends ESTestCase {
/** /**
* A module you can use for test cases that don't actually care about the * A module you can use for test cases that don't actually care about the
* entitlements module. * entitlement module.
*/ */
private static Module NO_ENTITLEMENTS_MODULE; private static Module NO_ENTITLEMENTS_MODULE;
@ -66,7 +66,11 @@ public class PolicyManagerTests extends ESTestCase {
var callerClass = this.getClass(); var callerClass = this.getClass();
var requestingModule = callerClass.getModule(); var requestingModule = callerClass.getModule();
assertEquals("No policy for the unnamed module", ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass)); assertEquals(
"No policy for the unnamed module",
ModuleEntitlements.NONE,
policyManager.getEntitlements(callerClass, Entitlement.class)
);
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap); assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
} }
@ -78,7 +82,7 @@ public class PolicyManagerTests extends ESTestCase {
var callerClass = this.getClass(); var callerClass = this.getClass();
var requestingModule = callerClass.getModule(); var requestingModule = callerClass.getModule();
assertEquals("No policy for this plugin", ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass)); assertEquals("No policy for this plugin", ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass, Entitlement.class));
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap); assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
} }
@ -90,11 +94,11 @@ public class PolicyManagerTests extends ESTestCase {
var callerClass = this.getClass(); var callerClass = this.getClass();
var requestingModule = callerClass.getModule(); var requestingModule = callerClass.getModule();
assertEquals(ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass)); assertEquals(ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass, Entitlement.class));
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap); assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
// A second time // A second time
assertEquals(ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass)); assertEquals(ModuleEntitlements.NONE, policyManager.getEntitlements(callerClass, Entitlement.class));
// Nothing new in the map // Nothing new in the map
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap); assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
@ -112,7 +116,7 @@ public class PolicyManagerTests extends ESTestCase {
// Any class from the current module (unnamed) will do // Any class from the current module (unnamed) will do
var callerClass = this.getClass(); var callerClass = this.getClass();
var entitlements = policyManager.getEntitlements(callerClass); var entitlements = policyManager.getEntitlements(callerClass, Entitlement.class);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
} }
@ -126,7 +130,11 @@ public class PolicyManagerTests extends ESTestCase {
var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer"); var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer");
var requestingModule = mockServerClass.getModule(); var requestingModule = mockServerClass.getModule();
assertEquals("No policy for this module in server", ModuleEntitlements.NONE, policyManager.getEntitlements(mockServerClass)); assertEquals(
"No policy for this module in server",
ModuleEntitlements.NONE,
policyManager.getEntitlements(mockServerClass, Entitlement.class)
);
assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap); assertEquals(Map.of(requestingModule, ModuleEntitlements.NONE), policyManager.moduleEntitlementsMap);
} }
@ -145,9 +153,8 @@ public class PolicyManagerTests extends ESTestCase {
// So we use a random module in the boot layer, and a random class from that module (not java.base -- it is // So we use a random module in the boot layer, and a random class from that module (not java.base -- it is
// loaded too early) to mimic a class that would be in the server module. // loaded too early) to mimic a class that would be in the server module.
var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer"); var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer");
var requestingModule = mockServerClass.getModule();
var entitlements = policyManager.getEntitlements(mockServerClass); var entitlements = policyManager.getEntitlements(mockServerClass, Entitlement.class);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat(entitlements.hasEntitlement(ExitVMEntitlement.class), is(true)); assertThat(entitlements.hasEntitlement(ExitVMEntitlement.class), is(true));
} }
@ -167,9 +174,8 @@ public class PolicyManagerTests extends ESTestCase {
var layer = createLayerForJar(jar, "org.example.plugin"); var layer = createLayerForJar(jar, "org.example.plugin");
var mockPluginClass = layer.findLoader("org.example.plugin").loadClass("q.B"); var mockPluginClass = layer.findLoader("org.example.plugin").loadClass("q.B");
var requestingModule = mockPluginClass.getModule();
var entitlements = policyManager.getEntitlements(mockPluginClass); var entitlements = policyManager.getEntitlements(mockPluginClass, Entitlement.class);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat( assertThat(
entitlements.getEntitlements(FileEntitlement.class).toList(), entitlements.getEntitlements(FileEntitlement.class).toList(),
@ -189,11 +195,11 @@ public class PolicyManagerTests extends ESTestCase {
// Any class from the current module (unnamed) will do // Any class from the current module (unnamed) will do
var callerClass = this.getClass(); var callerClass = this.getClass();
var entitlements = policyManager.getEntitlements(callerClass); var entitlements = policyManager.getEntitlements(callerClass, Entitlement.class);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1)); assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1));
var cachedResult = policyManager.moduleEntitlementsMap.values().stream().findFirst().get(); var cachedResult = policyManager.moduleEntitlementsMap.values().stream().findFirst().orElseThrow();
var entitlementsAgain = policyManager.getEntitlements(callerClass); var entitlementsAgain = policyManager.getEntitlements(callerClass, Entitlement.class);
// Nothing new in the map // Nothing new in the map
assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1)); assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1));

View file

@ -0,0 +1,4 @@
ALL-UNNAMED:
- network:
actions:
- connect

View file

@ -0,0 +1,4 @@
org.apache.httpcomponents.httpclient:
- network:
actions:
- connect # for URLHttpClient

View file

@ -0,0 +1,4 @@
ALL-UNNAMED:
- network:
actions:
- connect

View file

@ -0,0 +1,4 @@
org.apache.httpcomponents.httpclient:
- network:
actions:
- connect # For SamlRealm