Add "always denied" network access checks (#119867)

This commit is contained in:
Lorenzo Dematté 2025-01-13 09:26:55 +01:00 committed by GitHub
parent 80729f967a
commit d3a1d9b509
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 166 additions and 9 deletions

View file

@ -15,14 +15,18 @@ import java.io.PrintWriter;
import java.net.ContentHandlerFactory;
import java.net.DatagramSocketImplFactory;
import java.net.FileNameMap;
import java.net.ProxySelector;
import java.net.ResponseCache;
import java.net.SocketImplFactory;
import java.net.URL;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.util.List;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
@SuppressWarnings("unused") // Called from instrumentation code inserted by the Entitlements agent
@ -167,4 +171,22 @@ public interface EntitlementChecker {
void check$java_net_URLConnection$$setContentHandlerFactory(Class<?> callerClass, ContentHandlerFactory fac);
////////////////////
//
// Network access
//
void check$java_net_ProxySelector$$setDefault(Class<?> callerClass, ProxySelector ps);
void check$java_net_ResponseCache$$setDefault(Class<?> callerClass, ResponseCache rc);
void check$java_net_spi_InetAddressResolverProvider$(Class<?> callerClass);
void check$java_net_spi_URLStreamHandlerProvider$(Class<?> callerClass);
void check$java_net_URL$(Class<?> callerClass, String protocol, String host, int port, String file, URLStreamHandler handler);
void check$java_net_URL$(Class<?> callerClass, URL context, String spec, URLStreamHandler handler);
// The only implementation of SSLSession#getSessionContext(); unfortunately it's an interface, so we need to check the implementation
void check$sun_security_ssl_SSLSessionImpl$getSessionContext(Class<?> callerClass, SSLSession sslSession);
}

View file

@ -34,14 +34,19 @@ import org.elasticsearch.rest.RestStatus;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.DatagramSocket;
import java.net.DatagramSocketImpl;
import java.net.DatagramSocketImplFactory;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.ProxySelector;
import java.net.ResponseCache;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.net.spi.InetAddressResolver;
import java.net.spi.InetAddressResolverProvider;
import java.net.spi.URLStreamHandlerProvider;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
@ -50,6 +55,9 @@ import java.util.stream.Collectors;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import static java.util.Map.entry;
import static org.elasticsearch.entitlement.qa.common.RestEntitlementsCheckAction.CheckAction.alwaysDenied;
@ -57,6 +65,7 @@ import static org.elasticsearch.entitlement.qa.common.RestEntitlementsCheckActio
import static org.elasticsearch.entitlement.qa.common.RestEntitlementsCheckAction.CheckAction.forPlugins;
import static org.elasticsearch.rest.RestRequest.Method.GET;
@SuppressWarnings("unused")
public class RestEntitlementsCheckAction extends BaseRestHandler {
private static final Logger logger = LogManager.getLogger(RestEntitlementsCheckAction.class);
public static final Thread NO_OP_SHUTDOWN_HOOK = new Thread(() -> {}, "Shutdown hook for testing");
@ -125,9 +134,87 @@ public class RestEntitlementsCheckAction extends BaseRestHandler {
entry("socket_setSocketImplFactory", alwaysDenied(RestEntitlementsCheckAction::socket$$setSocketImplFactory)),
entry("url_setURLStreamHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::url$$setURLStreamHandlerFactory)),
entry("urlConnection_setFileNameMap", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setFileNameMap)),
entry("urlConnection_setContentHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setContentHandlerFactory))
entry("urlConnection_setContentHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setContentHandlerFactory)),
entry("proxySelector_setDefault", alwaysDenied(RestEntitlementsCheckAction::setDefaultProxySelector)),
entry("responseCache_setDefault", alwaysDenied(RestEntitlementsCheckAction::setDefaultResponseCache)),
entry("createInetAddressResolverProvider", alwaysDenied(RestEntitlementsCheckAction::createInetAddressResolverProvider)),
entry("createURLStreamHandlerProvider", alwaysDenied(RestEntitlementsCheckAction::createURLStreamHandlerProvider)),
entry("createURLWithURLStreamHandler", alwaysDenied(RestEntitlementsCheckAction::createURLWithURLStreamHandler)),
entry("createURLWithURLStreamHandler2", alwaysDenied(RestEntitlementsCheckAction::createURLWithURLStreamHandler2)),
entry("sslSessionImpl_getSessionContext", alwaysDenied(RestEntitlementsCheckAction::sslSessionImplGetSessionContext))
);
private static void createURLStreamHandlerProvider() {
var x = new URLStreamHandlerProvider() {
@Override
public URLStreamHandler createURLStreamHandler(String protocol) {
return null;
}
};
}
private static void sslSessionImplGetSessionContext() {
SSLSocketFactory factory = HttpsURLConnection.getDefaultSSLSocketFactory();
try (SSLSocket socket = (SSLSocket) factory.createSocket()) {
SSLSession session = socket.getSession();
session.getSessionContext();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@SuppressWarnings("deprecation")
private static void createURLWithURLStreamHandler() {
try {
var x = new URL("http", "host", 1234, "file", new URLStreamHandler() {
@Override
protected URLConnection openConnection(URL u) {
return null;
}
});
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}
@SuppressWarnings("deprecation")
private static void createURLWithURLStreamHandler2() {
try {
var x = new URL(null, "spec", new URLStreamHandler() {
@Override
protected URLConnection openConnection(URL u) {
return null;
}
});
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}
private static void createInetAddressResolverProvider() {
var x = new InetAddressResolverProvider() {
@Override
public InetAddressResolver get(Configuration configuration) {
return null;
}
@Override
public String name() {
return "TEST";
}
};
}
private static void setDefaultResponseCache() {
ResponseCache.setDefault(null);
}
private static void setDefaultProxySelector() {
ProxySelector.setDefault(null);
}
private static void setDefaultSSLContext() {
try {
SSLContext.setDefault(SSLContext.getDefault());
@ -270,12 +357,7 @@ public class RestEntitlementsCheckAction extends BaseRestHandler {
@SuppressForbidden(reason = "We're required to prevent calls to this forbidden API")
private static void datagramSocket$$setDatagramSocketImplFactory() {
try {
DatagramSocket.setDatagramSocketImplFactory(new DatagramSocketImplFactory() {
@Override
public DatagramSocketImpl createDatagramSocketImpl() {
throw new IllegalStateException();
}
});
DatagramSocket.setDatagramSocketImplFactory(() -> { throw new IllegalStateException(); });
} catch (IOException e) {
throw new IllegalStateException(e);
}

View file

@ -18,14 +18,18 @@ import java.io.PrintWriter;
import java.net.ContentHandlerFactory;
import java.net.DatagramSocketImplFactory;
import java.net.FileNameMap;
import java.net.ProxySelector;
import java.net.ResponseCache;
import java.net.SocketImplFactory;
import java.net.URL;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.util.List;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
/**
@ -310,4 +314,39 @@ public class ElasticsearchEntitlementChecker implements EntitlementChecker {
public void check$javax_net_ssl_SSLContext$$setDefault(Class<?> callerClass, SSLContext context) {
policyManager.checkChangeJVMGlobalState(callerClass);
}
@Override
public void check$java_net_ProxySelector$$setDefault(Class<?> callerClass, ProxySelector ps) {
policyManager.checkChangeNetworkHandling(callerClass);
}
@Override
public void check$java_net_ResponseCache$$setDefault(Class<?> callerClass, ResponseCache rc) {
policyManager.checkChangeNetworkHandling(callerClass);
}
@Override
public void check$java_net_spi_InetAddressResolverProvider$(Class<?> callerClass) {
policyManager.checkChangeNetworkHandling(callerClass);
}
@Override
public void check$java_net_spi_URLStreamHandlerProvider$(Class<?> callerClass) {
policyManager.checkChangeNetworkHandling(callerClass);
}
@Override
public void check$java_net_URL$(Class<?> callerClass, String protocol, String host, int port, String file, URLStreamHandler handler) {
policyManager.checkChangeNetworkHandling(callerClass);
}
@Override
public void check$java_net_URL$(Class<?> callerClass, URL context, String spec, URLStreamHandler handler) {
policyManager.checkChangeNetworkHandling(callerClass);
}
@Override
public void check$sun_security_ssl_SSLSessionImpl$getSessionContext(Class<?> callerClass, SSLSession sslSession) {
policyManager.checkReadSensitiveNetworkInformation(callerClass);
}
}

View file

@ -171,6 +171,20 @@ public class PolicyManager {
});
}
/**
* Check for operations that can modify the way network operations are handled
*/
public void checkChangeNetworkHandling(Class<?> callerClass) {
checkChangeJVMGlobalState(callerClass);
}
/**
* Check for operations that can access sensitive network information, e.g. secrets, tokens or SSL sessions
*/
public void checkReadSensitiveNetworkInformation(Class<?> callerClass) {
neverEntitled(callerClass, "access sensitive network information");
}
private String operationDescription(String methodName) {
// TODO: Use a more human-readable description. Perhaps share code with InstrumentationServiceImpl.parseCheckerMethodName
return methodName.substring(methodName.indexOf('$'));