Remove doPrivileged from ES modules (#127848)

Continuing the cleanup of SecurityManager related code, this commit
removes uses of doPrivileged in all Elasticsearch modules.
This commit is contained in:
Ryan Ernst 2025-05-09 11:15:48 -07:00 committed by GitHub
parent da553b11e3
commit 8ad272352b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 419 additions and 932 deletions

View file

@ -15,8 +15,6 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.telemetry.apm.internal.MetricNameValidator; import org.elasticsearch.telemetry.apm.internal.MetricNameValidator;
import org.elasticsearch.telemetry.metric.Instrument; import org.elasticsearch.telemetry.metric.Instrument;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function; import java.util.function.Function;
@ -35,7 +33,7 @@ public abstract class AbstractInstrument<T> implements Instrument {
public AbstractInstrument(Meter meter, Builder<T> builder) { public AbstractInstrument(Meter meter, Builder<T> builder) {
this.name = builder.getName(); this.name = builder.getName();
this.instrumentBuilder = m -> AccessController.doPrivileged((PrivilegedAction<T>) () -> builder.build(m)); this.instrumentBuilder = m -> builder.build(m);
this.delegate.set(this.instrumentBuilder.apply(meter)); this.delegate.set(this.instrumentBuilder.apply(meter));
} }

View file

@ -20,8 +20,6 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.telemetry.apm.internal.tracing.APMTracer; import org.elasticsearch.telemetry.apm.internal.tracing.APMTracer;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
@ -94,16 +92,13 @@ public class APMAgentSettings {
return; return;
} }
final String completeKey = "elastic.apm." + Objects.requireNonNull(key); final String completeKey = "elastic.apm." + Objects.requireNonNull(key);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> { if (value == null || value.isEmpty()) {
if (value == null || value.isEmpty()) { LOGGER.trace("Clearing system property [{}]", completeKey);
LOGGER.trace("Clearing system property [{}]", completeKey); System.clearProperty(completeKey);
System.clearProperty(completeKey); } else {
} else { LOGGER.trace("Setting setting property [{}] to [{}]", completeKey, value);
LOGGER.trace("Setting setting property [{}] to [{}]", completeKey, value); System.setProperty(completeKey, value);
System.setProperty(completeKey, value); }
}
return null;
});
} }
private static final String TELEMETRY_SETTING_PREFIX = "telemetry."; private static final String TELEMETRY_SETTING_PREFIX = "telemetry.";

View file

@ -18,8 +18,6 @@ import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.telemetry.apm.APMMeterRegistry; import org.elasticsearch.telemetry.apm.APMMeterRegistry;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.function.Supplier; import java.util.function.Supplier;
public class APMMeterService extends AbstractLifecycleComponent { public class APMMeterService extends AbstractLifecycleComponent {
@ -74,7 +72,7 @@ public class APMMeterService extends AbstractLifecycleComponent {
protected Meter createOtelMeter() { protected Meter createOtelMeter() {
assert this.enabled; assert this.enabled;
return AccessController.doPrivileged((PrivilegedAction<Meter>) otelMeterSupplier::get); return otelMeterSupplier.get();
} }
protected Meter createNoopMeter() { protected Meter createNoopMeter() {

View file

@ -39,8 +39,6 @@ import org.elasticsearch.telemetry.apm.internal.APMAgentSettings;
import org.elasticsearch.telemetry.tracing.TraceContext; import org.elasticsearch.telemetry.tracing.TraceContext;
import org.elasticsearch.telemetry.tracing.Traceable; import org.elasticsearch.telemetry.tracing.Traceable;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -145,11 +143,9 @@ public class APMTracer extends AbstractLifecycleComponent implements org.elastic
assert this.enabled; assert this.enabled;
assert this.services == null; assert this.services == null;
return AccessController.doPrivileged((PrivilegedAction<APMServices>) () -> { var openTelemetry = GlobalOpenTelemetry.get();
var openTelemetry = GlobalOpenTelemetry.get(); var tracer = openTelemetry.getTracer("elasticsearch", Build.current().version());
var tracer = openTelemetry.getTracer("elasticsearch", Build.current().version()); return new APMServices(tracer, openTelemetry);
return new APMServices(tracer, openTelemetry);
});
} }
private void destroyApmServices() { private void destroyApmServices() {
@ -175,7 +171,7 @@ public class APMTracer extends AbstractLifecycleComponent implements org.elastic
return; return;
} }
spans.computeIfAbsent(spanId, _spanId -> AccessController.doPrivileged((PrivilegedAction<Context>) () -> { spans.computeIfAbsent(spanId, _spanId -> {
logger.trace("Tracing [{}] [{}]", spanId, spanName); logger.trace("Tracing [{}] [{}]", spanId, spanName);
final SpanBuilder spanBuilder = services.tracer.spanBuilder(spanName); final SpanBuilder spanBuilder = services.tracer.spanBuilder(spanName);
@ -198,7 +194,7 @@ public class APMTracer extends AbstractLifecycleComponent implements org.elastic
updateThreadContext(traceContext, services, contextForNewSpan); updateThreadContext(traceContext, services, contextForNewSpan);
return contextForNewSpan; return contextForNewSpan;
})); });
} }
/** /**
@ -282,8 +278,7 @@ public class APMTracer extends AbstractLifecycleComponent implements org.elastic
public Releasable withScope(Traceable traceable) { public Releasable withScope(Traceable traceable) {
final Context context = spans.get(traceable.getSpanId()); final Context context = spans.get(traceable.getSpanId());
if (context != null) { if (context != null) {
var scope = AccessController.doPrivileged((PrivilegedAction<Scope>) context::makeCurrent); return context.makeCurrent()::close;
return scope::close;
} }
return () -> {}; return () -> {};
} }
@ -380,10 +375,7 @@ public class APMTracer extends AbstractLifecycleComponent implements org.elastic
final var span = Span.fromContextOrNull(spans.remove(traceable.getSpanId())); final var span = Span.fromContextOrNull(spans.remove(traceable.getSpanId()));
if (span != null) { if (span != null) {
logger.trace("Finishing trace [{}]", traceable); logger.trace("Finishing trace [{}]", traceable);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> { span.end();
span.end();
return null;
});
} }
} }
@ -392,10 +384,7 @@ public class APMTracer extends AbstractLifecycleComponent implements org.elastic
*/ */
@Override @Override
public void stopTrace() { public void stopTrace() {
AccessController.doPrivileged((PrivilegedAction<Void>) () -> { Span.current().end();
Span.current().end();
return null;
});
} }
@Override @Override

View file

@ -11,8 +11,6 @@ package org.elasticsearch.ingest.geoip;
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
@ -22,9 +20,6 @@ import java.net.Authenticator;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.net.PasswordAuthentication; import java.net.PasswordAuthentication;
import java.net.URL; import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays; import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
@ -88,46 +83,44 @@ class HttpClient {
final String originalAuthority = new URL(url).getAuthority(); final String originalAuthority = new URL(url).getAuthority();
return doPrivileged(() -> { String innerUrl = url;
String innerUrl = url; HttpURLConnection conn = createConnection(auth, innerUrl);
HttpURLConnection conn = createConnection(auth, innerUrl);
int redirectsCount = 0; int redirectsCount = 0;
while (true) { while (true) {
switch (conn.getResponseCode()) { switch (conn.getResponseCode()) {
case HTTP_OK: case HTTP_OK:
return getInputStream(conn); return getInputStream(conn);
case HTTP_MOVED_PERM: case HTTP_MOVED_PERM:
case HTTP_MOVED_TEMP: case HTTP_MOVED_TEMP:
case HTTP_SEE_OTHER: case HTTP_SEE_OTHER:
if (redirectsCount++ > 50) { if (redirectsCount++ > 50) {
throw new IllegalStateException("too many redirects connection to [" + url + "]"); throw new IllegalStateException("too many redirects connection to [" + url + "]");
} }
// deal with redirections (including relative urls) // deal with redirections (including relative urls)
final String location = conn.getHeaderField("Location"); final String location = conn.getHeaderField("Location");
final URL base = new URL(innerUrl); final URL base = new URL(innerUrl);
final URL next = new URL(base, location); final URL next = new URL(base, location);
innerUrl = next.toExternalForm(); innerUrl = next.toExternalForm();
// compare the *original* authority and the next authority to determine whether to include auth details. // compare the *original* authority and the next authority to determine whether to include auth details.
// this means that the host and port (if it is provided explicitly) are considered. it also means that if we // this means that the host and port (if it is provided explicitly) are considered. it also means that if we
// were to ping-pong back to the original authority, then we'd start including the auth details again. // were to ping-pong back to the original authority, then we'd start including the auth details again.
final String nextAuthority = next.getAuthority(); final String nextAuthority = next.getAuthority();
if (originalAuthority.equals(nextAuthority)) { if (originalAuthority.equals(nextAuthority)) {
conn = createConnection(auth, innerUrl); conn = createConnection(auth, innerUrl);
} else { } else {
conn = createConnection(NO_AUTH, innerUrl); conn = createConnection(NO_AUTH, innerUrl);
} }
break; break;
case HTTP_NOT_FOUND: case HTTP_NOT_FOUND:
throw new ResourceNotFoundException("{} not found", url); throw new ResourceNotFoundException("{} not found", url);
default: default:
int responseCode = conn.getResponseCode(); int responseCode = conn.getResponseCode();
throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), url); throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), url);
}
} }
}); }
} }
@SuppressForbidden(reason = "we need socket connection to download data from internet") @SuppressForbidden(reason = "we need socket connection to download data from internet")
@ -150,13 +143,4 @@ class HttpClient {
conn.setInstanceFollowRedirects(false); conn.setInstanceFollowRedirects(false);
return conn; return conn;
} }
private static <R> R doPrivileged(final CheckedSupplier<R, IOException> supplier) throws IOException {
SpecialPermission.check();
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<R>) supplier::get);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
} }

View file

@ -19,8 +19,6 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -492,7 +490,7 @@ public final class WhitelistLoader {
} }
} }
ClassLoader loader = AccessController.doPrivileged((PrivilegedAction<ClassLoader>) owner::getClassLoader); ClassLoader loader = owner.getClassLoader();
return new Whitelist(loader, whitelistClasses, whitelistStatics, whitelistClassBindings, Collections.emptyList()); return new Whitelist(loader, whitelistClasses, whitelistStatics, whitelistClassBindings, Collections.emptyList());
} }

View file

@ -22,8 +22,6 @@ import java.lang.invoke.LambdaConversionException;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.List; import java.util.List;
import static java.lang.invoke.MethodHandles.Lookup; import static java.lang.invoke.MethodHandles.Lookup;
@ -504,9 +502,7 @@ public final class LambdaBootstrap {
byte[] classBytes = cw.toByteArray(); byte[] classBytes = cw.toByteArray();
// DEBUG: // DEBUG:
// new ClassReader(classBytes).accept(new TraceClassVisitor(new PrintWriter(System.out)), ClassReader.SKIP_DEBUG); // new ClassReader(classBytes).accept(new TraceClassVisitor(new PrintWriter(System.out)), ClassReader.SKIP_DEBUG);
return AccessController.doPrivileged( return loader.defineLambda(lambdaClassType.getClassName(), classBytes);
(PrivilegedAction<Class<?>>) () -> loader.defineLambda(lambdaClassType.getClassName(), classBytes)
);
} }
/** /**

View file

@ -27,11 +27,7 @@ import org.objectweb.asm.commons.GeneratorAdapter;
import java.lang.invoke.MethodType; import java.lang.invoke.MethodType;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.Permissions; import java.security.Permissions;
import java.security.PrivilegedAction;
import java.security.ProtectionDomain;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -52,18 +48,12 @@ public final class PainlessScriptEngine implements ScriptEngine {
*/ */
public static final String NAME = "painless"; public static final String NAME = "painless";
/**
* Permissions context used during compilation.
*/
private static final AccessControlContext COMPILATION_CONTEXT;
/* /*
* Setup the allowed permissions. * Setup the allowed permissions.
*/ */
static { static {
final Permissions none = new Permissions(); final Permissions none = new Permissions();
none.setReadOnly(); none.setReadOnly();
COMPILATION_CONTEXT = new AccessControlContext(new ProtectionDomain[] { new ProtectionDomain(null, none) });
} }
/** /**
@ -123,12 +113,7 @@ public final class PainlessScriptEngine implements ScriptEngine {
SpecialPermission.check(); SpecialPermission.check();
// Create our loader (which loads compiled code with no permissions). // Create our loader (which loads compiled code with no permissions).
final Loader loader = AccessController.doPrivileged(new PrivilegedAction<Loader>() { final Loader loader = compiler.createLoader(getClass().getClassLoader());
@Override
public Loader run() {
return compiler.createLoader(getClass().getClassLoader());
}
});
ScriptScope scriptScope = compile(contextsToCompilers.get(context), loader, scriptName, scriptSource, params); ScriptScope scriptScope = compile(contextsToCompilers.get(context), loader, scriptName, scriptSource, params);
@ -398,17 +383,9 @@ public final class PainlessScriptEngine implements ScriptEngine {
try { try {
// Drop all permissions to actually compile the code itself. // Drop all permissions to actually compile the code itself.
return AccessController.doPrivileged(new PrivilegedAction<ScriptScope>() { String name = scriptName == null ? source : scriptName;
@Override return compiler.compile(loader, name, source, compilerSettings);
public ScriptScope run() {
String name = scriptName == null ? source : scriptName;
return compiler.compile(loader, name, source, compilerSettings);
}
}, COMPILATION_CONTEXT);
// Note that it is safe to catch any of the following errors since Painless is stateless. // Note that it is safe to catch any of the following errors since Painless is stateless.
} catch (SecurityException e) {
// security exceptions are rethrown so that they can propagate to the ES log, they are not user errors
throw e;
} catch (OutOfMemoryError | StackOverflowError | LinkageError | Exception e) { } catch (OutOfMemoryError | StackOverflowError | LinkageError | Exception e) {
throw convertToScriptException(source, e); throw convertToScriptException(source, e);
} }

View file

@ -137,10 +137,8 @@ public class AzureStorageCleanupThirdPartyTests extends AbstractThirdPartyReposi
.client("default", LocationMode.PRIMARY_ONLY, randomFrom(OperationPurpose.values())); .client("default", LocationMode.PRIMARY_ONLY, randomFrom(OperationPurpose.values()));
final BlobServiceClient client = azureBlobServiceClient.getSyncClient(); final BlobServiceClient client = azureBlobServiceClient.getSyncClient();
try { try {
SocketAccess.doPrivilegedException(() -> { final BlobContainerClient blobContainer = client.getBlobContainerClient(blobStore.toString());
final BlobContainerClient blobContainer = client.getBlobContainerClient(blobStore.toString()); blobContainer.exists();
return blobContainer.exists();
});
future.onFailure( future.onFailure(
new RuntimeException( new RuntimeException(
"The SAS token used in this test allowed for checking container existence. This test only supports tokens " "The SAS token used in this test allowed for checking container existence. This test only supports tokens "

View file

@ -75,6 +75,7 @@ import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
@ -233,11 +234,8 @@ public class AzureBlobStore implements BlobStore {
final BlobServiceClient client = client(purpose); final BlobServiceClient client = client(purpose);
try { try {
Boolean blobExists = SocketAccess.doPrivilegedException(() -> { final BlobClient azureBlob = client.getBlobContainerClient(container).getBlobClient(blob);
final BlobClient azureBlob = client.getBlobContainerClient(container).getBlobClient(blob); return azureBlob.exists();
return azureBlob.exists();
});
return Boolean.TRUE.equals(blobExists);
} catch (Exception e) { } catch (Exception e) {
logger.trace("can not access [{}] in container {{}}: {}", blob, container, e.getMessage()); logger.trace("can not access [{}] in container {{}}: {}", blob, container, e.getMessage());
throw new IOException("Unable to check if blob " + blob + " exists", e); throw new IOException("Unable to check if blob " + blob + " exists", e);
@ -247,32 +245,26 @@ public class AzureBlobStore implements BlobStore {
public DeleteResult deleteBlobDirectory(OperationPurpose purpose, String path) throws IOException { public DeleteResult deleteBlobDirectory(OperationPurpose purpose, String path) throws IOException {
final AtomicInteger blobsDeleted = new AtomicInteger(0); final AtomicInteger blobsDeleted = new AtomicInteger(0);
final AtomicLong bytesDeleted = new AtomicLong(0); final AtomicLong bytesDeleted = new AtomicLong(0);
final AzureBlobServiceClient client = getAzureBlobServiceClientClient(purpose);
SocketAccess.doPrivilegedVoidException(() -> { final BlobContainerAsyncClient blobContainerAsyncClient = client.getAsyncClient().getBlobContainerAsyncClient(container);
final AzureBlobServiceClient client = getAzureBlobServiceClientClient(purpose); final ListBlobsOptions options = new ListBlobsOptions().setPrefix(path).setDetails(new BlobListDetails().setRetrieveMetadata(true));
final BlobContainerAsyncClient blobContainerAsyncClient = client.getAsyncClient().getBlobContainerAsyncClient(container); final Flux<String> blobsFlux = blobContainerAsyncClient.listBlobs(options).filter(bi -> bi.isPrefix() == false).map(bi -> {
final ListBlobsOptions options = new ListBlobsOptions().setPrefix(path) bytesDeleted.addAndGet(bi.getProperties().getContentLength());
.setDetails(new BlobListDetails().setRetrieveMetadata(true)); blobsDeleted.incrementAndGet();
final Flux<String> blobsFlux = blobContainerAsyncClient.listBlobs(options).filter(bi -> bi.isPrefix() == false).map(bi -> { return bi.getName();
bytesDeleted.addAndGet(bi.getProperties().getContentLength());
blobsDeleted.incrementAndGet();
return bi.getName();
});
deleteListOfBlobs(client, blobsFlux);
}); });
deleteListOfBlobs(client, blobsFlux);
return new DeleteResult(blobsDeleted.get(), bytesDeleted.get()); return new DeleteResult(blobsDeleted.get(), bytesDeleted.get());
} }
void deleteBlobs(OperationPurpose purpose, Iterator<String> blobNames) { void deleteBlobs(OperationPurpose purpose, Iterator<String> blobNames) throws IOException {
if (blobNames.hasNext() == false) { if (blobNames.hasNext() == false) {
return; return;
} }
SocketAccess.doPrivilegedVoidException( deleteListOfBlobs(
() -> deleteListOfBlobs( getAzureBlobServiceClientClient(purpose),
getAzureBlobServiceClientClient(purpose), Flux.fromStream(StreamSupport.stream(Spliterators.spliteratorUnknownSize(blobNames, Spliterator.ORDERED), false))
Flux.fromStream(StreamSupport.stream(Spliterators.spliteratorUnknownSize(blobNames, Spliterator.ORDERED), false))
)
); );
} }
@ -346,17 +338,17 @@ public class AzureBlobStore implements BlobStore {
final BlobServiceClient syncClient = azureBlobServiceClient.getSyncClient(); final BlobServiceClient syncClient = azureBlobServiceClient.getSyncClient();
final BlobServiceAsyncClient asyncClient = azureBlobServiceClient.getAsyncClient(); final BlobServiceAsyncClient asyncClient = azureBlobServiceClient.getAsyncClient();
return SocketAccess.doPrivilegedException(() -> { final BlobContainerClient blobContainerClient = syncClient.getBlobContainerClient(container);
final BlobContainerClient blobContainerClient = syncClient.getBlobContainerClient(container); final BlobClient blobClient = blobContainerClient.getBlobClient(blob);
final BlobClient blobClient = blobContainerClient.getBlobClient(blob); final long totalSize;
final long totalSize; if (length == null) {
if (length == null) { totalSize = blobClient.getProperties().getBlobSize();
totalSize = blobClient.getProperties().getBlobSize(); } else {
} else { totalSize = position + length;
totalSize = position + length; }
} BlobAsyncClient blobAsyncClient = asyncClient.getBlobContainerAsyncClient(container).getBlobAsyncClient(blob);
BlobAsyncClient blobAsyncClient = asyncClient.getBlobContainerAsyncClient(container).getBlobAsyncClient(blob); int maxReadRetries = service.getMaxReadRetries(clientName);
int maxReadRetries = service.getMaxReadRetries(clientName); try {
return new AzureInputStream( return new AzureInputStream(
blobAsyncClient, blobAsyncClient,
position, position,
@ -365,7 +357,9 @@ public class AzureBlobStore implements BlobStore {
maxReadRetries, maxReadRetries,
azureBlobServiceClient.getAllocator() azureBlobServiceClient.getAllocator()
); );
}); } catch (IOException e) {
throw new UncheckedIOException(e);
}
} }
public Map<String, BlobMetadata> listBlobsByPrefix(OperationPurpose purpose, String keyPath, String prefix) throws IOException { public Map<String, BlobMetadata> listBlobsByPrefix(OperationPurpose purpose, String keyPath, String prefix) throws IOException {
@ -373,22 +367,20 @@ public class AzureBlobStore implements BlobStore {
logger.trace(() -> format("listing container [%s], keyPath [%s], prefix [%s]", container, keyPath, prefix)); logger.trace(() -> format("listing container [%s], keyPath [%s], prefix [%s]", container, keyPath, prefix));
try { try {
final BlobServiceClient client = client(purpose); final BlobServiceClient client = client(purpose);
SocketAccess.doPrivilegedVoidException(() -> { final BlobContainerClient containerClient = client.getBlobContainerClient(container);
final BlobContainerClient containerClient = client.getBlobContainerClient(container); final BlobListDetails details = new BlobListDetails().setRetrieveMetadata(true);
final BlobListDetails details = new BlobListDetails().setRetrieveMetadata(true); final ListBlobsOptions listBlobsOptions = new ListBlobsOptions().setPrefix(keyPath + (prefix == null ? "" : prefix))
final ListBlobsOptions listBlobsOptions = new ListBlobsOptions().setPrefix(keyPath + (prefix == null ? "" : prefix)) .setDetails(details);
.setDetails(details);
for (final BlobItem blobItem : containerClient.listBlobsByHierarchy("/", listBlobsOptions, null)) { for (final BlobItem blobItem : containerClient.listBlobsByHierarchy("/", listBlobsOptions, null)) {
BlobItemProperties properties = blobItem.getProperties(); BlobItemProperties properties = blobItem.getProperties();
if (blobItem.isPrefix()) { if (blobItem.isPrefix()) {
continue; continue;
}
String blobName = blobItem.getName().substring(keyPath.length());
blobsBuilder.put(blobName, new BlobMetadata(blobName, properties.getContentLength()));
} }
}); String blobName = blobItem.getName().substring(keyPath.length());
blobsBuilder.put(blobName, new BlobMetadata(blobName, properties.getContentLength()));
}
} catch (Exception e) { } catch (Exception e) {
throw new IOException("Unable to list blobs by prefix [" + prefix + "] for path " + keyPath, e); throw new IOException("Unable to list blobs by prefix [" + prefix + "] for path " + keyPath, e);
} }
@ -401,24 +393,22 @@ public class AzureBlobStore implements BlobStore {
try { try {
final BlobServiceClient client = client(purpose); final BlobServiceClient client = client(purpose);
SocketAccess.doPrivilegedVoidException(() -> { BlobContainerClient blobContainer = client.getBlobContainerClient(container);
BlobContainerClient blobContainer = client.getBlobContainerClient(container); final ListBlobsOptions listBlobsOptions = new ListBlobsOptions();
final ListBlobsOptions listBlobsOptions = new ListBlobsOptions(); listBlobsOptions.setPrefix(keyPath).setDetails(new BlobListDetails().setRetrieveMetadata(true));
listBlobsOptions.setPrefix(keyPath).setDetails(new BlobListDetails().setRetrieveMetadata(true)); for (final BlobItem blobItem : blobContainer.listBlobsByHierarchy("/", listBlobsOptions, null)) {
for (final BlobItem blobItem : blobContainer.listBlobsByHierarchy("/", listBlobsOptions, null)) { Boolean isPrefix = blobItem.isPrefix();
Boolean isPrefix = blobItem.isPrefix(); if (isPrefix != null && isPrefix) {
if (isPrefix != null && isPrefix) { String directoryName = blobItem.getName();
String directoryName = blobItem.getName(); directoryName = directoryName.substring(keyPath.length());
directoryName = directoryName.substring(keyPath.length()); if (directoryName.isEmpty()) {
if (directoryName.isEmpty()) { continue;
continue;
}
// Remove trailing slash
directoryName = directoryName.substring(0, directoryName.length() - 1);
childrenBuilder.put(directoryName, new AzureBlobContainer(BlobPath.EMPTY.add(blobItem.getName()), this));
} }
// Remove trailing slash
directoryName = directoryName.substring(0, directoryName.length() - 1);
childrenBuilder.put(directoryName, new AzureBlobContainer(BlobPath.EMPTY.add(blobItem.getName()), this));
} }
}); }
} catch (Exception e) { } catch (Exception e) {
throw new IOException("Unable to provide children blob containers for " + path, e); throw new IOException("Unable to provide children blob containers for " + path, e);
} }
@ -448,13 +438,8 @@ public class AzureBlobStore implements BlobStore {
return; return;
} }
final String blockId = makeMultipartBlockId(); final String blockId = makeMultipartBlockId();
SocketAccess.doPrivilegedVoidException( blockBlobAsyncClient.stageBlock(blockId, Flux.fromArray(BytesReference.toByteBuffers(buffer.bytes())), buffer.size())
() -> blockBlobAsyncClient.stageBlock( .block();
blockId,
Flux.fromArray(BytesReference.toByteBuffers(buffer.bytes())),
buffer.size()
).block()
);
finishPart(blockId); finishPart(blockId);
} }
@ -464,9 +449,7 @@ public class AzureBlobStore implements BlobStore {
writeBlob(purpose, blobName, buffer.bytes(), failIfAlreadyExists); writeBlob(purpose, blobName, buffer.bytes(), failIfAlreadyExists);
} else { } else {
flushBuffer(); flushBuffer();
SocketAccess.doPrivilegedVoidException( blockBlobAsyncClient.commitBlockList(parts, failIfAlreadyExists == false).block();
() -> blockBlobAsyncClient.commitBlockList(parts, failIfAlreadyExists == false).block()
);
} }
} }
@ -514,20 +497,18 @@ public class AzureBlobStore implements BlobStore {
long blobSize, long blobSize,
boolean failIfAlreadyExists boolean failIfAlreadyExists
) { ) {
SocketAccess.doPrivilegedVoidException(() -> { final BlobServiceAsyncClient asyncClient = asyncClient(purpose);
final BlobServiceAsyncClient asyncClient = asyncClient(purpose);
final BlobAsyncClient blobAsyncClient = asyncClient.getBlobContainerAsyncClient(container).getBlobAsyncClient(blobName); final BlobAsyncClient blobAsyncClient = asyncClient.getBlobContainerAsyncClient(container).getBlobAsyncClient(blobName);
final BlockBlobAsyncClient blockBlobAsyncClient = blobAsyncClient.getBlockBlobAsyncClient(); final BlockBlobAsyncClient blockBlobAsyncClient = blobAsyncClient.getBlockBlobAsyncClient();
final BlockBlobSimpleUploadOptions options = new BlockBlobSimpleUploadOptions(byteBufferFlux, blobSize); final BlockBlobSimpleUploadOptions options = new BlockBlobSimpleUploadOptions(byteBufferFlux, blobSize);
BlobRequestConditions requestConditions = new BlobRequestConditions(); BlobRequestConditions requestConditions = new BlobRequestConditions();
if (failIfAlreadyExists) { if (failIfAlreadyExists) {
requestConditions.setIfNoneMatch("*"); requestConditions.setIfNoneMatch("*");
} }
options.setRequestConditions(requestConditions); options.setRequestConditions(requestConditions);
blockBlobAsyncClient.uploadWithResponse(options).block(); blockBlobAsyncClient.uploadWithResponse(options).block();
});
} }
private void executeMultipartUpload( private void executeMultipartUpload(
@ -537,29 +518,27 @@ public class AzureBlobStore implements BlobStore {
long blobSize, long blobSize,
boolean failIfAlreadyExists boolean failIfAlreadyExists
) { ) {
SocketAccess.doPrivilegedVoidException(() -> { final BlobServiceAsyncClient asyncClient = asyncClient(purpose);
final BlobServiceAsyncClient asyncClient = asyncClient(purpose); final BlobAsyncClient blobAsyncClient = asyncClient.getBlobContainerAsyncClient(container).getBlobAsyncClient(blobName);
final BlobAsyncClient blobAsyncClient = asyncClient.getBlobContainerAsyncClient(container).getBlobAsyncClient(blobName); final BlockBlobAsyncClient blockBlobAsyncClient = blobAsyncClient.getBlockBlobAsyncClient();
final BlockBlobAsyncClient blockBlobAsyncClient = blobAsyncClient.getBlockBlobAsyncClient();
final long partSize = getUploadBlockSize(); final long partSize = getUploadBlockSize();
final Tuple<Long, Long> multiParts = numberOfMultiparts(blobSize, partSize); final Tuple<Long, Long> multiParts = numberOfMultiparts(blobSize, partSize);
final int nbParts = multiParts.v1().intValue(); final int nbParts = multiParts.v1().intValue();
final long lastPartSize = multiParts.v2(); final long lastPartSize = multiParts.v2();
assert blobSize == (((nbParts - 1) * partSize) + lastPartSize) : "blobSize does not match multipart sizes"; assert blobSize == (((nbParts - 1) * partSize) + lastPartSize) : "blobSize does not match multipart sizes";
final List<String> blockIds = new ArrayList<>(nbParts); final List<String> blockIds = new ArrayList<>(nbParts);
for (int i = 0; i < nbParts; i++) { for (int i = 0; i < nbParts; i++) {
final long length = i < nbParts - 1 ? partSize : lastPartSize; final long length = i < nbParts - 1 ? partSize : lastPartSize;
Flux<ByteBuffer> byteBufferFlux = convertStreamToByteBuffer(inputStream, length, DEFAULT_UPLOAD_BUFFERS_SIZE); Flux<ByteBuffer> byteBufferFlux = convertStreamToByteBuffer(inputStream, length, DEFAULT_UPLOAD_BUFFERS_SIZE);
final String blockId = makeMultipartBlockId(); final String blockId = makeMultipartBlockId();
blockBlobAsyncClient.stageBlock(blockId, byteBufferFlux, length).block(); blockBlobAsyncClient.stageBlock(blockId, byteBufferFlux, length).block();
blockIds.add(blockId); blockIds.add(blockId);
} }
blockBlobAsyncClient.commitBlockList(blockIds, failIfAlreadyExists == false).block(); blockBlobAsyncClient.commitBlockList(blockIds, failIfAlreadyExists == false).block();
});
} }
private static final Base64.Encoder base64Encoder = Base64.getEncoder().withoutPadding(); private static final Base64.Encoder base64Encoder = Base64.getEncoder().withoutPadding();
@ -951,16 +930,16 @@ public class AzureBlobStore implements BlobStore {
OptionalBytesReference getRegister(OperationPurpose purpose, String blobPath, String containerPath, String blobKey) { OptionalBytesReference getRegister(OperationPurpose purpose, String blobPath, String containerPath, String blobKey) {
try { try {
return SocketAccess.doPrivilegedException( return OptionalBytesReference.of(
() -> OptionalBytesReference.of( downloadRegisterBlob(
downloadRegisterBlob( containerPath,
containerPath, blobKey,
blobKey, getAzureBlobServiceClientClient(purpose).getSyncClient().getBlobContainerClient(container).getBlobClient(blobPath),
getAzureBlobServiceClientClient(purpose).getSyncClient().getBlobContainerClient(container).getBlobClient(blobPath), null
null
)
) )
); );
} catch (IOException e) {
throw new UncheckedIOException(e);
} catch (Exception e) { } catch (Exception e) {
if (Throwables.getRootCause(e) instanceof BlobStorageException blobStorageException if (Throwables.getRootCause(e) instanceof BlobStorageException blobStorageException
&& blobStorageException.getStatusCode() == RestStatus.NOT_FOUND.getStatus()) { && blobStorageException.getStatusCode() == RestStatus.NOT_FOUND.getStatus()) {
@ -980,17 +959,17 @@ public class AzureBlobStore implements BlobStore {
) { ) {
BlobContainerUtils.ensureValidRegisterContent(updated); BlobContainerUtils.ensureValidRegisterContent(updated);
try { try {
return SocketAccess.doPrivilegedException( return OptionalBytesReference.of(
() -> OptionalBytesReference.of( innerCompareAndExchangeRegister(
innerCompareAndExchangeRegister( containerPath,
containerPath, blobKey,
blobKey, getAzureBlobServiceClientClient(purpose).getSyncClient().getBlobContainerClient(container).getBlobClient(blobPath),
getAzureBlobServiceClientClient(purpose).getSyncClient().getBlobContainerClient(container).getBlobClient(blobPath), expected,
expected, updated
updated
)
) )
); );
} catch (IOException e) {
throw new UncheckedIOException(e);
} catch (Exception e) { } catch (Exception e) {
if (Throwables.getRootCause(e) instanceof BlobStorageException blobStorageException) { if (Throwables.getRootCause(e) instanceof BlobStorageException blobStorageException) {
if (blobStorageException.getStatusCode() == RestStatus.PRECONDITION_FAILED.getStatus() if (blobStorageException.getStatusCode() == RestStatus.PRECONDITION_FAILED.getStatus()

View file

@ -43,7 +43,6 @@ import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.TimeValue;
import org.elasticsearch.repositories.azure.executors.PrivilegedExecutor;
import org.elasticsearch.repositories.azure.executors.ReactorScheduledExecutorService; import org.elasticsearch.repositories.azure.executors.ReactorScheduledExecutorService;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -52,7 +51,6 @@ import org.elasticsearch.transport.netty4.NettyAllocator;
import java.net.URL; import java.net.URL;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadFactory;
@ -140,10 +138,7 @@ class AzureClientProvider extends AbstractLifecycleComponent {
// Most of the code that needs special permissions (i.e. jackson serializers generation) is executed // Most of the code that needs special permissions (i.e. jackson serializers generation) is executed
// in the event loop executor. That's the reason why we should provide an executor that allows the // in the event loop executor. That's the reason why we should provide an executor that allows the
// execution of privileged code // execution of privileged code
final EventLoopGroup eventLoopGroup = new NioEventLoopGroup( final EventLoopGroup eventLoopGroup = new NioEventLoopGroup(eventLoopThreadsFromSettings(settings), eventLoopExecutor);
eventLoopThreadsFromSettings(settings),
new PrivilegedExecutor(eventLoopExecutor)
);
final TimeValue openConnectionTimeout = OPEN_CONNECTION_TIMEOUT.get(settings); final TimeValue openConnectionTimeout = OPEN_CONNECTION_TIMEOUT.get(settings);
final TimeValue maxIdleTime = MAX_IDLE_TIME.get(settings); final TimeValue maxIdleTime = MAX_IDLE_TIME.get(settings);
@ -210,24 +205,14 @@ class AzureClientProvider extends AbstractLifecycleComponent {
builder.endpoint(secondaryUri); builder.endpoint(secondaryUri);
} }
BlobServiceClient blobServiceClient = SocketAccess.doPrivilegedException(builder::buildClient); BlobServiceClient blobServiceClient = builder.buildClient();
BlobServiceAsyncClient asyncClient = SocketAccess.doPrivilegedException(builder::buildAsyncClient); BlobServiceAsyncClient asyncClient = builder.buildAsyncClient();
return new AzureBlobServiceClient(blobServiceClient, asyncClient, settings.getMaxRetries(), byteBufAllocator); return new AzureBlobServiceClient(blobServiceClient, asyncClient, settings.getMaxRetries(), byteBufAllocator);
} }
@Override @Override
protected void doStart() { protected void doStart() {
ReactorScheduledExecutorService executorService = new ReactorScheduledExecutorService(threadPool, reactorExecutorName) { ReactorScheduledExecutorService executorService = new ReactorScheduledExecutorService(threadPool, reactorExecutorName);
@Override
protected Runnable decorateRunnable(Runnable command) {
return () -> SocketAccess.doPrivilegedVoidException(command::run);
}
@Override
protected <V> Callable<V> decorateCallable(Callable<V> callable) {
return () -> SocketAccess.doPrivilegedException(callable::call);
}
};
// The only way to configure the schedulers used by the SDK is to inject a new global factory. This is a bit ugly... // The only way to configure the schedulers used by the SDK is to inject a new global factory. This is a bit ugly...
// See https://github.com/Azure/azure-sdk-for-java/issues/17272 for a feature request to avoid this need. // See https://github.com/Azure/azure-sdk-for-java/issues/17272 for a feature request to avoid this need.

View file

@ -9,8 +9,6 @@
package org.elasticsearch.repositories.azure; package org.elasticsearch.repositories.azure;
import com.azure.core.util.serializer.JacksonAdapter;
import org.apache.lucene.util.SetOnce; import org.apache.lucene.util.SetOnce;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
@ -28,8 +26,6 @@ import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.NamedXContentRegistry;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@ -44,11 +40,6 @@ public class AzureRepositoryPlugin extends Plugin implements RepositoryPlugin, R
public static final String REPOSITORY_THREAD_POOL_NAME = "repository_azure"; public static final String REPOSITORY_THREAD_POOL_NAME = "repository_azure";
public static final String NETTY_EVENT_LOOP_THREAD_POOL_NAME = "azure_event_loop"; public static final String NETTY_EVENT_LOOP_THREAD_POOL_NAME = "azure_event_loop";
static {
// Trigger static initialization with the plugin class loader so we have access to the proper xml parser
AccessController.doPrivileged((PrivilegedAction<Object>) JacksonAdapter::createDefaultSerializerAdapter);
}
// protected for testing // protected for testing
final SetOnce<AzureStorageService> azureStoreService = new SetOnce<>(); final SetOnce<AzureStorageService> azureStoreService = new SetOnce<>();
private final Settings settings; private final Settings settings;

View file

@ -155,8 +155,7 @@ final class AzureStorageSettings {
this.maxRetries = maxRetries; this.maxRetries = maxRetries;
this.credentialsUsageFeatures = Strings.hasText(key) ? Set.of("uses_key_credentials") this.credentialsUsageFeatures = Strings.hasText(key) ? Set.of("uses_key_credentials")
: Strings.hasText(sasToken) ? Set.of("uses_sas_token") : Strings.hasText(sasToken) ? Set.of("uses_sas_token")
: SocketAccess.doPrivilegedException(() -> System.getenv("AZURE_FEDERATED_TOKEN_FILE")) == null : System.getenv("AZURE_FEDERATED_TOKEN_FILE") == null ? Set.of("uses_default_credentials", "uses_managed_identity")
? Set.of("uses_default_credentials", "uses_managed_identity")
: Set.of("uses_default_credentials", "uses_workload_identity"); : Set.of("uses_default_credentials", "uses_workload_identity");
// Register the proxy if we have any // Register the proxy if we have any

View file

@ -1,60 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.repositories.azure;
import org.apache.logging.log4j.core.util.Throwables;
import org.elasticsearch.SpecialPermission;
import java.io.IOException;
import java.net.SocketPermission;
import java.net.URISyntaxException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
/**
* This plugin uses azure libraries to connect to azure storage services. For these remote calls the plugin needs
* {@link SocketPermission} 'connect' to establish connections. This class wraps the operations requiring access in
* {@link AccessController#doPrivileged(PrivilegedAction)} blocks.
*/
public final class SocketAccess {
private SocketAccess() {}
public static <T> T doPrivilegedException(PrivilegedExceptionAction<T> operation) {
SpecialPermission.check();
try {
return AccessController.doPrivileged(operation);
} catch (PrivilegedActionException e) {
Throwables.rethrow(e.getCause());
assert false : "always throws";
return null;
}
}
public static void doPrivilegedVoidException(StorageRunnable action) {
SpecialPermission.check();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
action.executeCouldThrow();
return null;
});
} catch (PrivilegedActionException e) {
Throwables.rethrow(e.getCause());
}
}
@FunctionalInterface
public interface StorageRunnable {
void executeCouldThrow() throws URISyntaxException, IOException;
}
}

View file

@ -1,30 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.repositories.azure.executors;
import org.elasticsearch.repositories.azure.SocketAccess;
import java.util.concurrent.Executor;
/**
* Executor that grants security permissions to the tasks executed on it.
*/
public class PrivilegedExecutor implements Executor {
private final Executor delegate;
public PrivilegedExecutor(Executor delegate) {
this.delegate = delegate;
}
@Override
public void execute(Runnable command) {
delegate.execute(() -> SocketAccess.doPrivilegedVoidException(command::run));
}
}

View file

@ -49,7 +49,7 @@ public class ReactorScheduledExecutorService extends AbstractExecutorService imp
public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) { public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
Scheduler.ScheduledCancellable schedule = threadPool.schedule(() -> { Scheduler.ScheduledCancellable schedule = threadPool.schedule(() -> {
try { try {
decorateCallable(callable).call(); callable.call();
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -59,22 +59,20 @@ public class ReactorScheduledExecutorService extends AbstractExecutorService imp
} }
public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) { public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
Runnable decoratedCommand = decorateRunnable(command); Scheduler.ScheduledCancellable schedule = threadPool.schedule(command, new TimeValue(delay, unit), delegate);
Scheduler.ScheduledCancellable schedule = threadPool.schedule(decoratedCommand, new TimeValue(delay, unit), delegate);
return new ReactorFuture<>(schedule); return new ReactorFuture<>(schedule);
} }
@Override @Override
public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) {
Runnable decoratedCommand = decorateRunnable(command);
return threadPool.scheduler().scheduleAtFixedRate(() -> { return threadPool.scheduler().scheduleAtFixedRate(() -> {
try { try {
delegate.execute(decoratedCommand); delegate.execute(command);
} catch (EsRejectedExecutionException e) { } catch (EsRejectedExecutionException e) {
if (e.isExecutorShutdown()) { if (e.isExecutorShutdown()) {
logger.debug( logger.debug(
() -> format("could not schedule execution of [%s] on [%s] as executor is shut down", decoratedCommand, delegate), () -> format("could not schedule execution of [%s] on [%s] as executor is shut down", command, delegate),
e e
); );
} else { } else {
@ -86,9 +84,7 @@ public class ReactorScheduledExecutorService extends AbstractExecutorService imp
@Override @Override
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) {
Runnable decorateRunnable = decorateRunnable(command); Scheduler.Cancellable cancellable = threadPool.scheduleWithFixedDelay(command, new TimeValue(delay, unit), delegate);
Scheduler.Cancellable cancellable = threadPool.scheduleWithFixedDelay(decorateRunnable, new TimeValue(delay, unit), delegate);
return new ReactorFuture<>(cancellable); return new ReactorFuture<>(cancellable);
} }
@ -120,15 +116,7 @@ public class ReactorScheduledExecutorService extends AbstractExecutorService imp
@Override @Override
public void execute(Runnable command) { public void execute(Runnable command) {
delegate.execute(decorateRunnable(command)); delegate.execute(command);
}
protected Runnable decorateRunnable(Runnable command) {
return command;
}
protected <V> Callable<V> decorateCallable(Callable<V> callable) {
return callable;
} }
private static final class ReactorFuture<V> implements ScheduledFuture<V> { private static final class ReactorFuture<V> implements ScheduledFuture<V> {

View file

@ -52,7 +52,6 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.Channels; import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel; import java.nio.channels.WritableByteChannel;
import java.nio.file.FileAlreadyExistsException; import java.nio.file.FileAlreadyExistsException;
import java.util.ArrayList; import java.util.ArrayList;
@ -172,38 +171,34 @@ class GoogleCloudStorageBlobStore implements BlobStore {
Map<String, BlobMetadata> listBlobsByPrefix(OperationPurpose purpose, String path, String prefix) throws IOException { Map<String, BlobMetadata> listBlobsByPrefix(OperationPurpose purpose, String path, String prefix) throws IOException {
final String pathPrefix = buildKey(path, prefix); final String pathPrefix = buildKey(path, prefix);
final Map<String, BlobMetadata> mapBuilder = new HashMap<>(); final Map<String, BlobMetadata> mapBuilder = new HashMap<>();
SocketAccess.doPrivilegedVoidIOException( client().meteredList(purpose, bucketName, BlobListOption.currentDirectory(), BlobListOption.prefix(pathPrefix))
() -> client().meteredList(purpose, bucketName, BlobListOption.currentDirectory(), BlobListOption.prefix(pathPrefix)) .iterateAll()
.iterateAll() .forEach(blob -> {
.forEach(blob -> { assert blob.getName().startsWith(path);
assert blob.getName().startsWith(path); if (blob.isDirectory() == false) {
if (blob.isDirectory() == false) { final String suffixName = blob.getName().substring(path.length());
final String suffixName = blob.getName().substring(path.length()); mapBuilder.put(suffixName, new BlobMetadata(suffixName, blob.getSize()));
mapBuilder.put(suffixName, new BlobMetadata(suffixName, blob.getSize())); }
} });
})
);
return Map.copyOf(mapBuilder); return Map.copyOf(mapBuilder);
} }
Map<String, BlobContainer> listChildren(OperationPurpose purpose, BlobPath path) throws IOException { Map<String, BlobContainer> listChildren(OperationPurpose purpose, BlobPath path) throws IOException {
final String pathStr = path.buildAsString(); final String pathStr = path.buildAsString();
final Map<String, BlobContainer> mapBuilder = new HashMap<>(); final Map<String, BlobContainer> mapBuilder = new HashMap<>();
SocketAccess.doPrivilegedVoidIOException( client().meteredList(purpose, bucketName, BlobListOption.currentDirectory(), BlobListOption.prefix(pathStr))
() -> client().meteredList(purpose, bucketName, BlobListOption.currentDirectory(), BlobListOption.prefix(pathStr)) .iterateAll()
.iterateAll() .forEach(blob -> {
.forEach(blob -> { if (blob.isDirectory()) {
if (blob.isDirectory()) { assert blob.getName().startsWith(pathStr);
assert blob.getName().startsWith(pathStr); assert blob.getName().endsWith("/");
assert blob.getName().endsWith("/"); // Strip path prefix and trailing slash
// Strip path prefix and trailing slash final String suffixName = blob.getName().substring(pathStr.length(), blob.getName().length() - 1);
final String suffixName = blob.getName().substring(pathStr.length(), blob.getName().length() - 1); if (suffixName.isEmpty() == false) {
if (suffixName.isEmpty() == false) { mapBuilder.put(suffixName, new GoogleCloudStorageBlobContainer(path.add(suffixName), this));
mapBuilder.put(suffixName, new GoogleCloudStorageBlobContainer(path.add(suffixName), this));
}
} }
}) }
); });
return Map.copyOf(mapBuilder); return Map.copyOf(mapBuilder);
} }
@ -216,7 +211,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
*/ */
boolean blobExists(OperationPurpose purpose, String blobName) throws IOException { boolean blobExists(OperationPurpose purpose, String blobName) throws IOException {
final BlobId blobId = BlobId.of(bucketName, blobName); final BlobId blobId = BlobId.of(bucketName, blobName);
final Blob blob = SocketAccess.doPrivilegedIOException(() -> client().meteredGet(purpose, blobId)); final Blob blob = client().meteredGet(purpose, blobId);
return blob != null; return blob != null;
} }
@ -375,9 +370,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
} }
private void initResumableStream() throws IOException { private void initResumableStream() throws IOException {
final var writeChannel = SocketAccess.doPrivilegedIOException( final var writeChannel = client().meteredWriter(purpose, blobInfo, writeOptions);
() -> client().meteredWriter(purpose, blobInfo, writeOptions)
);
channelRef.set(writeChannel); channelRef.set(writeChannel);
resumableStream = new FilterOutputStream(Channels.newOutputStream(new WritableBlobChannel(writeChannel))) { resumableStream = new FilterOutputStream(Channels.newOutputStream(new WritableBlobChannel(writeChannel))) {
@Override @Override
@ -396,7 +389,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
}); });
final WritableByteChannel writeChannel = channelRef.get(); final WritableByteChannel writeChannel = channelRef.get();
if (writeChannel != null) { if (writeChannel != null) {
SocketAccess.doPrivilegedVoidIOException(writeChannel::close); writeChannel.close();
} else { } else {
writeBlob(purpose, blobName, buffer.bytes(), failIfAlreadyExists); writeBlob(purpose, blobName, buffer.bytes(), failIfAlreadyExists);
} }
@ -453,15 +446,13 @@ class GoogleCloudStorageBlobStore implements BlobStore {
} }
for (int retry = 0; retry < 3; ++retry) { for (int retry = 0; retry < 3; ++retry) {
try { try {
final WriteChannel writeChannel = SocketAccess.doPrivilegedIOException( final WriteChannel writeChannel = client().meteredWriter(purpose, blobInfo, writeOptions);
() -> client().meteredWriter(purpose, blobInfo, writeOptions)
);
/* /*
* It is not enough to wrap the call to Streams#copy, we have to wrap the privileged calls too; this is because Streams#copy * It is not enough to wrap the call to Streams#copy, we have to wrap the privileged calls too; this is because Streams#copy
* is in the stacktrace and is not granted the permissions needed to close and write the channel. * is in the stacktrace and is not granted the permissions needed to close and write the channel.
*/ */
org.elasticsearch.core.Streams.copy(inputStream, Channels.newOutputStream(new WritableBlobChannel(writeChannel)), buffer); org.elasticsearch.core.Streams.copy(inputStream, Channels.newOutputStream(new WritableBlobChannel(writeChannel)), buffer);
SocketAccess.doPrivilegedVoidIOException(writeChannel::close); writeChannel.close();
return; return;
} catch (final StorageException se) { } catch (final StorageException se) {
final int errorCode = se.getCode(); final int errorCode = se.getCode();
@ -508,9 +499,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
final Storage.BlobTargetOption[] targetOptions = failIfAlreadyExists final Storage.BlobTargetOption[] targetOptions = failIfAlreadyExists
? new Storage.BlobTargetOption[] { Storage.BlobTargetOption.doesNotExist() } ? new Storage.BlobTargetOption[] { Storage.BlobTargetOption.doesNotExist() }
: new Storage.BlobTargetOption[0]; : new Storage.BlobTargetOption[0];
SocketAccess.doPrivilegedVoidIOException( client().meteredCreate(purpose, blobInfo, buffer, offset, blobSize, targetOptions);
() -> client().meteredCreate(purpose, blobInfo, buffer, offset, blobSize, targetOptions)
);
} catch (final StorageException se) { } catch (final StorageException se) {
if (failIfAlreadyExists && se.getCode() == HTTP_PRECON_FAILED) { if (failIfAlreadyExists && se.getCode() == HTTP_PRECON_FAILED) {
throw new FileAlreadyExistsException(blobInfo.getBlobId().getName(), null, se.getMessage()); throw new FileAlreadyExistsException(blobInfo.getBlobId().getName(), null, se.getMessage());
@ -526,32 +515,30 @@ class GoogleCloudStorageBlobStore implements BlobStore {
* @param pathStr Name of path to delete * @param pathStr Name of path to delete
*/ */
DeleteResult deleteDirectory(OperationPurpose purpose, String pathStr) throws IOException { DeleteResult deleteDirectory(OperationPurpose purpose, String pathStr) throws IOException {
return SocketAccess.doPrivilegedIOException(() -> { DeleteResult deleteResult = DeleteResult.ZERO;
DeleteResult deleteResult = DeleteResult.ZERO; MeteredStorage.MeteredBlobPage meteredPage = client().meteredList(purpose, bucketName, BlobListOption.prefix(pathStr));
MeteredStorage.MeteredBlobPage meteredPage = client().meteredList(purpose, bucketName, BlobListOption.prefix(pathStr)); do {
do { final AtomicLong blobsDeleted = new AtomicLong(0L);
final AtomicLong blobsDeleted = new AtomicLong(0L); final AtomicLong bytesDeleted = new AtomicLong(0L);
final AtomicLong bytesDeleted = new AtomicLong(0L); var blobs = meteredPage.getValues().iterator();
var blobs = meteredPage.getValues().iterator(); deleteBlobs(purpose, new Iterator<>() {
deleteBlobs(purpose, new Iterator<>() { @Override
@Override public boolean hasNext() {
public boolean hasNext() { return blobs.hasNext();
return blobs.hasNext(); }
}
@Override @Override
public String next() { public String next() {
final Blob next = blobs.next(); final Blob next = blobs.next();
blobsDeleted.incrementAndGet(); blobsDeleted.incrementAndGet();
bytesDeleted.addAndGet(next.getSize()); bytesDeleted.addAndGet(next.getSize());
return next.getName(); return next.getName();
} }
}); });
deleteResult = deleteResult.add(blobsDeleted.get(), bytesDeleted.get()); deleteResult = deleteResult.add(blobsDeleted.get(), bytesDeleted.get());
meteredPage = meteredPage.getNextPage(); meteredPage = meteredPage.getNextPage();
} while (meteredPage != null); } while (meteredPage != null);
return deleteResult; return deleteResult;
});
} }
/** /**
@ -577,45 +564,43 @@ class GoogleCloudStorageBlobStore implements BlobStore {
}; };
final List<BlobId> failedBlobs = Collections.synchronizedList(new ArrayList<>()); final List<BlobId> failedBlobs = Collections.synchronizedList(new ArrayList<>());
try { try {
SocketAccess.doPrivilegedVoidIOException(() -> { final AtomicReference<StorageException> ioe = new AtomicReference<>();
final AtomicReference<StorageException> ioe = new AtomicReference<>(); StorageBatch batch = client().batch();
StorageBatch batch = client().batch(); int pendingDeletesInBatch = 0;
int pendingDeletesInBatch = 0; while (blobIdsToDelete.hasNext()) {
while (blobIdsToDelete.hasNext()) { BlobId blob = blobIdsToDelete.next();
BlobId blob = blobIdsToDelete.next(); batch.delete(blob).notify(new BatchResult.Callback<>() {
batch.delete(blob).notify(new BatchResult.Callback<>() { @Override
@Override public void success(Boolean result) {}
public void success(Boolean result) {}
@Override @Override
public void error(StorageException exception) { public void error(StorageException exception) {
if (exception.getCode() != HTTP_NOT_FOUND) { if (exception.getCode() != HTTP_NOT_FOUND) {
// track up to 10 failed blob deletions for the exception message below // track up to 10 failed blob deletions for the exception message below
if (failedBlobs.size() < 10) { if (failedBlobs.size() < 10) {
failedBlobs.add(blob); failedBlobs.add(blob);
} }
if (ioe.compareAndSet(null, exception) == false) { if (ioe.compareAndSet(null, exception) == false) {
ioe.get().addSuppressed(exception); ioe.get().addSuppressed(exception);
}
} }
} }
});
pendingDeletesInBatch++;
if (pendingDeletesInBatch % MAX_DELETES_PER_BATCH == 0) {
batch.submit();
batch = client().batch();
pendingDeletesInBatch = 0;
} }
} });
if (pendingDeletesInBatch > 0) { pendingDeletesInBatch++;
if (pendingDeletesInBatch % MAX_DELETES_PER_BATCH == 0) {
batch.submit(); batch.submit();
batch = client().batch();
pendingDeletesInBatch = 0;
} }
}
if (pendingDeletesInBatch > 0) {
batch.submit();
}
final StorageException exception = ioe.get(); final StorageException exception = ioe.get();
if (exception != null) { if (exception != null) {
throw exception; throw exception;
} }
});
} catch (final Exception e) { } catch (final Exception e) {
throw new IOException("Exception when deleting blobs " + failedBlobs, e); throw new IOException("Exception when deleting blobs " + failedBlobs, e);
} }
@ -644,7 +629,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
@Override @Override
public int write(final ByteBuffer src) throws IOException { public int write(final ByteBuffer src) throws IOException {
try { try {
return SocketAccess.doPrivilegedIOException(() -> channel.write(src)); return channel.write(src);
} catch (IOException e) { } catch (IOException e) {
// BaseStorageWriteChannel#write wraps StorageException in an IOException, but BaseStorageWriteChannel#close // BaseStorageWriteChannel#write wraps StorageException in an IOException, but BaseStorageWriteChannel#close
// does not, if we unwrap StorageExceptions here, it simplifies our retry-on-gone logic // does not, if we unwrap StorageExceptions here, it simplifies our retry-on-gone logic
@ -669,10 +654,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
OptionalBytesReference getRegister(OperationPurpose purpose, String blobName, String container, String key) throws IOException { OptionalBytesReference getRegister(OperationPurpose purpose, String blobName, String container, String key) throws IOException {
final var blobId = BlobId.of(bucketName, blobName); final var blobId = BlobId.of(bucketName, blobName);
try ( try (var meteredReadChannel = client().meteredReader(purpose, blobId); var stream = Channels.newInputStream(meteredReadChannel)) {
var meteredReadChannel = SocketAccess.doPrivilegedIOException(() -> client().meteredReader(purpose, blobId));
var stream = new PrivilegedReadChannelStream(meteredReadChannel)
) {
return OptionalBytesReference.of(BlobContainerUtils.getRegisterUsingConsistentRead(stream, container, key)); return OptionalBytesReference.of(BlobContainerUtils.getRegisterUsingConsistentRead(stream, container, key));
} catch (Exception e) { } catch (Exception e) {
final var serviceException = unwrapServiceException(e); final var serviceException = unwrapServiceException(e);
@ -697,7 +679,7 @@ class GoogleCloudStorageBlobStore implements BlobStore {
BlobContainerUtils.ensureValidRegisterContent(updated); BlobContainerUtils.ensureValidRegisterContent(updated);
final var blobId = BlobId.of(bucketName, blobName); final var blobId = BlobId.of(bucketName, blobName);
final var blob = SocketAccess.doPrivilegedIOException(() -> client().meteredGet(purpose, blobId)); final var blob = client().meteredGet(purpose, blobId);
final long generation; final long generation;
if (blob == null || blob.getGeneration() == null) { if (blob == null || blob.getGeneration() == null) {
@ -708,10 +690,8 @@ class GoogleCloudStorageBlobStore implements BlobStore {
} else { } else {
generation = blob.getGeneration(); generation = blob.getGeneration();
try ( try (
var stream = new PrivilegedReadChannelStream( var stream = Channels.newInputStream(
SocketAccess.doPrivilegedIOException( client().meteredReader(purpose, blobId, Storage.BlobSourceOption.generationMatch(generation))
() -> client().meteredReader(purpose, blobId, Storage.BlobSourceOption.generationMatch(generation))
)
) )
) { ) {
final var witness = BlobContainerUtils.getRegisterUsingConsistentRead(stream, container, key); final var witness = BlobContainerUtils.getRegisterUsingConsistentRead(stream, container, key);
@ -741,15 +721,13 @@ class GoogleCloudStorageBlobStore implements BlobStore {
BaseServiceException finalException = null; BaseServiceException finalException = null;
while (true) { while (true) {
try { try {
SocketAccess.doPrivilegedVoidIOException( client().meteredCreate(
() -> client().meteredCreate( purpose,
purpose, blobInfo,
blobInfo, bytesRef.bytes,
bytesRef.bytes, bytesRef.offset,
bytesRef.offset, bytesRef.length,
bytesRef.length, Storage.BlobTargetOption.generationMatch()
Storage.BlobTargetOption.generationMatch()
)
); );
return OptionalBytesReference.of(expected); return OptionalBytesReference.of(expected);
} catch (Exception e) { } catch (Exception e) {
@ -791,34 +769,4 @@ class GoogleCloudStorageBlobStore implements BlobStore {
} }
return null; return null;
} }
private static final class PrivilegedReadChannelStream extends InputStream {
private final InputStream stream;
PrivilegedReadChannelStream(ReadableByteChannel channel) {
stream = Channels.newInputStream(channel);
}
@Override
public int read(byte[] b) throws IOException {
return SocketAccess.doPrivilegedIOException(() -> stream.read(b));
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
return SocketAccess.doPrivilegedIOException(() -> stream.read(b, off, len));
}
@Override
public void close() throws IOException {
SocketAccess.doPrivilegedVoidIOException(stream::close);
}
@Override
public int read() throws IOException {
return SocketAccess.doPrivilegedIOException(stream::read);
}
}
} }

View file

@ -255,14 +255,12 @@ public class GoogleCloudStorageClientSettings {
} }
try (InputStream credStream = credentialsFileSetting.get(settings)) { try (InputStream credStream = credentialsFileSetting.get(settings)) {
final Collection<String> scopes = Collections.singleton(StorageScopes.DEVSTORAGE_FULL_CONTROL); final Collection<String> scopes = Collections.singleton(StorageScopes.DEVSTORAGE_FULL_CONTROL);
return SocketAccess.doPrivilegedIOException(() -> { NetHttpTransport netHttpTransport = new NetHttpTransport.Builder().setProxy(proxy).build();
NetHttpTransport netHttpTransport = new NetHttpTransport.Builder().setProxy(proxy).build(); final ServiceAccountCredentials credentials = ServiceAccountCredentials.fromStream(credStream, () -> netHttpTransport);
final ServiceAccountCredentials credentials = ServiceAccountCredentials.fromStream(credStream, () -> netHttpTransport); if (credentials.createScopedRequired()) {
if (credentials.createScopedRequired()) { return (ServiceAccountCredentials) credentials.createScoped(scopes);
return (ServiceAccountCredentials) credentials.createScoped(scopes); }
} return credentials;
return credentials;
});
} }
} catch (final Exception e) { } catch (final Exception e) {
throw new IllegalArgumentException("failed to load GCS client credentials from [" + credentialsFileSetting.getKey() + "]", e); throw new IllegalArgumentException("failed to load GCS client credentials from [" + credentialsFileSetting.getKey() + "]", e);

View file

@ -81,23 +81,21 @@ class GoogleCloudStorageRetryingInputStream extends InputStream {
try { try {
return RetryHelper.runWithRetries(() -> { return RetryHelper.runWithRetries(() -> {
try { try {
return SocketAccess.doPrivilegedIOException(() -> { final var meteredGet = client.meteredObjectsGet(purpose, blobId.getBucket(), blobId.getName());
final var meteredGet = client.meteredObjectsGet(purpose, blobId.getBucket(), blobId.getName()); meteredGet.setReturnRawInputStream(true);
meteredGet.setReturnRawInputStream(true);
if (currentOffset > 0 || start > 0 || end < Long.MAX_VALUE - 1) { if (currentOffset > 0 || start > 0 || end < Long.MAX_VALUE - 1) {
if (meteredGet.getRequestHeaders() != null) { if (meteredGet.getRequestHeaders() != null) {
meteredGet.getRequestHeaders().setRange("bytes=" + Math.addExact(start, currentOffset) + "-" + end); meteredGet.getRequestHeaders().setRange("bytes=" + Math.addExact(start, currentOffset) + "-" + end);
}
} }
final HttpResponse resp = meteredGet.executeMedia(); }
final Long contentLength = resp.getHeaders().getContentLength(); final HttpResponse resp = meteredGet.executeMedia();
InputStream content = resp.getContent(); final Long contentLength = resp.getHeaders().getContentLength();
if (contentLength != null) { InputStream content = resp.getContent();
content = new ContentLengthValidatingInputStream(content, contentLength); if (contentLength != null) {
} content = new ContentLengthValidatingInputStream(content, contentLength);
return content; }
}); return content;
} catch (IOException e) { } catch (IOException e) {
throw StorageException.translate(e); throw StorageException.translate(e);
} }

View file

@ -148,12 +148,14 @@ public class GoogleCloudStorageService {
*/ */
private MeteredStorage createClient(GoogleCloudStorageClientSettings gcsClientSettings, GcsRepositoryStatsCollector statsCollector) private MeteredStorage createClient(GoogleCloudStorageClientSettings gcsClientSettings, GcsRepositoryStatsCollector statsCollector)
throws IOException { throws IOException {
final HttpTransport httpTransport = SocketAccess.doPrivilegedIOException(() -> {
final NetHttpTransport.Builder builder = new NetHttpTransport.Builder(); final NetHttpTransport.Builder builder = new NetHttpTransport.Builder();
// requires java.lang.RuntimePermission "setFactory" // requires java.lang.RuntimePermission "setFactory"
// Pin the TLS trust certificates. // Pin the TLS trust certificates.
// We manually load the key store from jks instead of using GoogleUtils.getCertificateTrustStore() because that uses a .p12 // We manually load the key store from jks instead of using GoogleUtils.getCertificateTrustStore() because that uses a .p12
// store format not compatible with FIPS mode. // store format not compatible with FIPS mode.
final HttpTransport httpTransport;
try {
final KeyStore certTrustStore = SecurityUtils.getJavaKeyStore(); final KeyStore certTrustStore = SecurityUtils.getJavaKeyStore();
try (InputStream keyStoreStream = GoogleUtils.class.getResourceAsStream("google.jks")) { try (InputStream keyStoreStream = GoogleUtils.class.getResourceAsStream("google.jks")) {
SecurityUtils.loadKeyStore(certTrustStore, keyStoreStream, "notasecret"); SecurityUtils.loadKeyStore(certTrustStore, keyStoreStream, "notasecret");
@ -164,8 +166,12 @@ public class GoogleCloudStorageService {
builder.setProxy(proxy); builder.setProxy(proxy);
notifyProxyIsSet(proxy); notifyProxyIsSet(proxy);
} }
return builder.build(); httpTransport = builder.build();
}); } catch (RuntimeException | IOException e) {
throw e;
} catch (Exception e) {
throw new RuntimeException(e);
}
final HttpTransportOptions httpTransportOptions = new HttpTransportOptions( final HttpTransportOptions httpTransportOptions = new HttpTransportOptions(
HttpTransportOptions.newBuilder() HttpTransportOptions.newBuilder()
@ -209,7 +215,7 @@ public class GoogleCloudStorageService {
} else { } else {
String defaultProjectId = null; String defaultProjectId = null;
try { try {
defaultProjectId = SocketAccess.doPrivilegedIOException(ServiceOptions::getDefaultProjectId); defaultProjectId = ServiceOptions.getDefaultProjectId();
if (defaultProjectId != null) { if (defaultProjectId != null) {
storageOptionsBuilder.setProjectId(defaultProjectId); storageOptionsBuilder.setProjectId(defaultProjectId);
} }
@ -220,12 +226,10 @@ public class GoogleCloudStorageService {
try { try {
// fallback to manually load project ID here as the above ServiceOptions method has the metadata endpoint hardcoded, // fallback to manually load project ID here as the above ServiceOptions method has the metadata endpoint hardcoded,
// which makes it impossible to test // which makes it impossible to test
SocketAccess.doPrivilegedVoidIOException(() -> { final String projectId = getDefaultProjectId(gcsClientSettings.getProxy());
final String projectId = getDefaultProjectId(gcsClientSettings.getProxy()); if (projectId != null) {
if (projectId != null) { storageOptionsBuilder.setProjectId(projectId);
storageOptionsBuilder.setProjectId(projectId); }
}
});
} catch (Exception e) { } catch (Exception e) {
logger.warn("failed to load default project id fallback", e); logger.warn("failed to load default project id fallback", e);
} }
@ -233,7 +237,7 @@ public class GoogleCloudStorageService {
} }
if (gcsClientSettings.getCredential() == null) { if (gcsClientSettings.getCredential() == null) {
try { try {
storageOptionsBuilder.setCredentials(SocketAccess.doPrivilegedIOException(GoogleCredentials::getApplicationDefault)); storageOptionsBuilder.setCredentials(GoogleCredentials.getApplicationDefault());
} catch (Exception e) { } catch (Exception e) {
logger.warn("failed to load Application Default Credentials", e); logger.warn("failed to load Application Default Credentials", e);
} }

View file

@ -30,8 +30,6 @@ import org.elasticsearch.core.SuppressForbidden;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Iterator; import java.util.Iterator;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -64,17 +62,15 @@ public class MeteredStorage {
@SuppressForbidden(reason = "need access to storage client") @SuppressForbidden(reason = "need access to storage client")
private static com.google.api.services.storage.Storage getStorageRpc(Storage client) { private static com.google.api.services.storage.Storage getStorageRpc(Storage client) {
return AccessController.doPrivileged((PrivilegedAction<com.google.api.services.storage.Storage>) () -> { assert client.getOptions().getRpc() instanceof HttpStorageRpc;
assert client.getOptions().getRpc() instanceof HttpStorageRpc; assert Stream.of(client.getOptions().getRpc().getClass().getDeclaredFields()).anyMatch(f -> f.getName().equals("storage"));
assert Stream.of(client.getOptions().getRpc().getClass().getDeclaredFields()).anyMatch(f -> f.getName().equals("storage")); try {
try { final Field storageField = client.getOptions().getRpc().getClass().getDeclaredField("storage");
final Field storageField = client.getOptions().getRpc().getClass().getDeclaredField("storage"); storageField.setAccessible(true);
storageField.setAccessible(true); return (com.google.api.services.storage.Storage) storageField.get(client.getOptions().getRpc());
return (com.google.api.services.storage.Storage) storageField.get(client.getOptions().getRpc()); } catch (Exception e) {
} catch (Exception e) { throw new IllegalStateException("storage could not be set up", e);
throw new IllegalStateException("storage could not be set up", e); }
}
});
} }
public MeteredBlobPage meteredList(OperationPurpose purpose, String bucket, Storage.BlobListOption... options) throws IOException { public MeteredBlobPage meteredList(OperationPurpose purpose, String bucket, Storage.BlobListOption... options) throws IOException {

View file

@ -1,62 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.repositories.gcs;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.core.CheckedRunnable;
import java.io.IOException;
import java.net.SocketPermission;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
/**
* This plugin uses google api/client libraries to connect to google cloud services. For these remote calls the plugin
* needs {@link SocketPermission} 'connect' to establish connections. This class wraps the operations requiring access
* in {@link AccessController#doPrivileged(PrivilegedAction)} blocks.
*/
final class SocketAccess {
private SocketAccess() {}
public static <T> T doPrivilegedIOException(PrivilegedExceptionAction<T> operation) throws IOException {
SpecialPermission.check();
try {
return AccessController.doPrivileged(operation);
} catch (PrivilegedActionException e) {
throw causeAsIOException(e);
}
}
public static void doPrivilegedVoidIOException(CheckedRunnable<IOException> action) throws IOException {
SpecialPermission.check();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
action.run();
return null;
});
} catch (PrivilegedActionException e) {
throw causeAsIOException(e);
}
}
private static IOException causeAsIOException(PrivilegedActionException e) {
final Throwable cause = e.getCause();
if (cause instanceof IOException ioException) {
return ioException;
}
if (cause instanceof RuntimeException runtimeException) {
throw runtimeException;
}
throw new RuntimeException(cause);
}
}

View file

@ -188,7 +188,7 @@ public class GoogleCloudStorageClientSettingsTests extends ESTestCase {
var proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(InetAddress.getLoopbackAddress(), proxyServer.getPort())); var proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(InetAddress.getLoopbackAddress(), proxyServer.getPort()));
ServiceAccountCredentials credentials = loadCredential(settings, clientName, proxy); ServiceAccountCredentials credentials = loadCredential(settings, clientName, proxy);
assertNotNull(credentials); assertNotNull(credentials);
assertEquals("proxy_access_token", SocketAccess.doPrivilegedIOException(credentials::refreshAccessToken).getTokenValue()); assertEquals("proxy_access_token", credentials.refreshAccessToken().getTokenValue());
} }
} }

View file

@ -198,7 +198,7 @@ public class GoogleCloudStorageServiceTests extends ESTestCase {
}; };
try (proxyServer) { try (proxyServer) {
var proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(InetAddress.getLoopbackAddress(), proxyServer.getPort())); var proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(InetAddress.getLoopbackAddress(), proxyServer.getPort()));
assertEquals(proxyProjectId, SocketAccess.doPrivilegedIOException(() -> GoogleCloudStorageService.getDefaultProjectId(proxy))); assertEquals(proxyProjectId, GoogleCloudStorageService.getDefaultProjectId(proxy));
} }
} }
} }

View file

@ -36,11 +36,7 @@ public class MockHttpProxyServerTests extends ESTestCase {
var httpClient = HttpClients.custom() var httpClient = HttpClients.custom()
.setRoutePlanner(new DefaultProxyRoutePlanner(new HttpHost(InetAddress.getLoopbackAddress(), proxyServer.getPort()))) .setRoutePlanner(new DefaultProxyRoutePlanner(new HttpHost(InetAddress.getLoopbackAddress(), proxyServer.getPort())))
.build(); .build();
try ( try (proxyServer; httpClient; var httpResponse = httpClient.execute(new HttpGet("http://googleapis.com/"))) {
proxyServer;
httpClient;
var httpResponse = SocketAccess.doPrivilegedIOException(() -> httpClient.execute(new HttpGet("http://googleapis.com/")))
) {
assertEquals(httpBody.length(), httpResponse.getEntity().getContentLength()); assertEquals(httpBody.length(), httpResponse.getEntity().getContentLength());
assertEquals(httpBody, EntityUtils.toString(httpResponse.getEntity())); assertEquals(httpBody, EntityUtils.toString(httpResponse.getEntity()));
} }

View file

@ -105,7 +105,7 @@ class S3BlobContainer extends AbstractBlobContainer {
@Override @Override
public boolean blobExists(OperationPurpose purpose, String blobName) { public boolean blobExists(OperationPurpose purpose, String blobName) {
try (AmazonS3Reference clientReference = blobStore.clientReference()) { try (AmazonS3Reference clientReference = blobStore.clientReference()) {
return SocketAccess.doPrivileged(() -> doesObjectExist(purpose, clientReference, blobStore.bucket(), buildKey(blobName))); return doesObjectExist(purpose, clientReference, blobStore.bucket(), buildKey(blobName));
} catch (final Exception e) { } catch (final Exception e) {
throw new BlobStoreException("Failed to check if blob [" + blobName + "] exists", e); throw new BlobStoreException("Failed to check if blob [" + blobName + "] exists", e);
} }
@ -145,14 +145,11 @@ class S3BlobContainer extends AbstractBlobContainer {
throws IOException { throws IOException {
assert BlobContainer.assertPurposeConsistency(purpose, blobName); assert BlobContainer.assertPurposeConsistency(purpose, blobName);
assert inputStream.markSupported() : "No mark support on inputStream breaks the S3 SDK's ability to retry requests"; assert inputStream.markSupported() : "No mark support on inputStream breaks the S3 SDK's ability to retry requests";
SocketAccess.doPrivilegedIOException(() -> { if (blobSize <= getLargeBlobThresholdInBytes()) {
if (blobSize <= getLargeBlobThresholdInBytes()) { executeSingleUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize);
executeSingleUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize); } else {
} else { executeMultipartUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize);
executeMultipartUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize); }
}
return null;
});
} }
@Override @Override
@ -186,13 +183,9 @@ class S3BlobContainer extends AbstractBlobContainer {
assert lastPart == false : "use single part upload if there's only a single part"; assert lastPart == false : "use single part upload if there's only a single part";
try (var clientReference = blobStore.clientReference()) { try (var clientReference = blobStore.clientReference()) {
uploadId.set( uploadId.set(
SocketAccess.doPrivileged( clientReference.client()
() -> clientReference.client() .createMultipartUpload(createMultipartUpload(purpose, Operation.PUT_MULTIPART_OBJECT, absoluteBlobKey))
.createMultipartUpload( .uploadId()
createMultipartUpload(purpose, Operation.PUT_MULTIPART_OBJECT, absoluteBlobKey)
)
.uploadId()
)
); );
} }
if (Strings.isEmpty(uploadId.get())) { if (Strings.isEmpty(uploadId.get())) {
@ -211,10 +204,8 @@ class S3BlobContainer extends AbstractBlobContainer {
final InputStream partContentStream = buffer.bytes().streamInput(); final InputStream partContentStream = buffer.bytes().streamInput();
final UploadPartResponse uploadResponse; final UploadPartResponse uploadResponse;
try (var clientReference = blobStore.clientReference()) { try (var clientReference = blobStore.clientReference()) {
uploadResponse = SocketAccess.doPrivileged( uploadResponse = clientReference.client()
() -> clientReference.client() .uploadPart(uploadRequest, RequestBody.fromInputStream(partContentStream, buffer.size()));
.uploadPart(uploadRequest, RequestBody.fromInputStream(partContentStream, buffer.size()))
);
} }
finishPart(CompletedPart.builder().partNumber(parts.size() + 1).eTag(uploadResponse.eTag()).build()); finishPart(CompletedPart.builder().partNumber(parts.size() + 1).eTag(uploadResponse.eTag()).build());
} }
@ -238,9 +229,7 @@ class S3BlobContainer extends AbstractBlobContainer {
); );
final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build(); final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build();
try (var clientReference = blobStore.clientReference()) { try (var clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid( clientReference.client().completeMultipartUpload(completeMultipartUploadRequest);
() -> clientReference.client().completeMultipartUpload(completeMultipartUploadRequest)
);
} }
} }
} }
@ -300,7 +289,7 @@ class S3BlobContainer extends AbstractBlobContainer {
S3BlobStore.configureRequestForMetrics(abortMultipartUploadRequestBuilder, blobStore, Operation.ABORT_MULTIPART_OBJECT, purpose); S3BlobStore.configureRequestForMetrics(abortMultipartUploadRequestBuilder, blobStore, Operation.ABORT_MULTIPART_OBJECT, purpose);
final var abortMultipartUploadRequest = abortMultipartUploadRequestBuilder.build(); final var abortMultipartUploadRequest = abortMultipartUploadRequestBuilder.build();
try (var clientReference = blobStore.clientReference()) { try (var clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> clientReference.client().abortMultipartUpload(abortMultipartUploadRequest)); clientReference.client().abortMultipartUpload(abortMultipartUploadRequest);
} }
} }
@ -391,7 +380,7 @@ class S3BlobContainer extends AbstractBlobContainer {
S3BlobStore.configureRequestForMetrics(copyObjectRequestBuilder, blobStore, Operation.COPY_OBJECT, purpose); S3BlobStore.configureRequestForMetrics(copyObjectRequestBuilder, blobStore, Operation.COPY_OBJECT, purpose);
final var copyObjectRequest = copyObjectRequestBuilder.build(); final var copyObjectRequest = copyObjectRequestBuilder.build();
try (AmazonS3Reference clientReference = blobStore.clientReference()) { try (AmazonS3Reference clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> clientReference.client().copyObject(copyObjectRequest)); clientReference.client().copyObject(copyObjectRequest);
} }
} }
} catch (final SdkException e) { } catch (final SdkException e) {
@ -417,7 +406,7 @@ class S3BlobContainer extends AbstractBlobContainer {
listObjectsRequestBuilder.continuationToken(prevListing.nextContinuationToken()); listObjectsRequestBuilder.continuationToken(prevListing.nextContinuationToken());
} }
final var listObjectsRequest = listObjectsRequestBuilder.build(); final var listObjectsRequest = listObjectsRequestBuilder.build();
final var listObjectsResponse = SocketAccess.doPrivileged(() -> clientReference.client().listObjectsV2(listObjectsRequest)); final var listObjectsResponse = clientReference.client().listObjectsV2(listObjectsRequest);
final Iterator<String> blobNameIterator = Iterators.map(listObjectsResponse.contents().iterator(), s3Object -> { final Iterator<String> blobNameIterator = Iterators.map(listObjectsResponse.contents().iterator(), s3Object -> {
deletedBlobs.incrementAndGet(); deletedBlobs.incrementAndGet();
deletedBytes.addAndGet(s3Object.size()); deletedBytes.addAndGet(s3Object.size());
@ -539,7 +528,7 @@ class S3BlobContainer extends AbstractBlobContainer {
} }
S3BlobStore.configureRequestForMetrics(listObjectsRequestBuilder, blobStore, Operation.LIST_OBJECTS, operationPurpose); S3BlobStore.configureRequestForMetrics(listObjectsRequestBuilder, blobStore, Operation.LIST_OBJECTS, operationPurpose);
final var listObjectsRequest = listObjectsRequestBuilder.build(); final var listObjectsRequest = listObjectsRequestBuilder.build();
return SocketAccess.doPrivileged(() -> clientReference.client().listObjectsV2(listObjectsRequest)); return clientReference.client().listObjectsV2(listObjectsRequest);
} }
} }
@ -579,9 +568,7 @@ class S3BlobContainer extends AbstractBlobContainer {
S3BlobStore.configureRequestForMetrics(putRequestBuilder, blobStore, Operation.PUT_OBJECT, purpose); S3BlobStore.configureRequestForMetrics(putRequestBuilder, blobStore, Operation.PUT_OBJECT, purpose);
final var putRequest = putRequestBuilder.build(); final var putRequest = putRequestBuilder.build();
SocketAccess.doPrivilegedVoid( clientReference.client().putObject(putRequest, RequestBody.fromInputStream(input, blobSize));
() -> clientReference.client().putObject(putRequest, RequestBody.fromInputStream(input, blobSize))
);
} catch (final SdkException e) { } catch (final SdkException e) {
throw new IOException("Unable to upload object [" + blobName + "] using a single upload", e); throw new IOException("Unable to upload object [" + blobName + "] using a single upload", e);
} }
@ -618,9 +605,7 @@ class S3BlobContainer extends AbstractBlobContainer {
try { try {
final String uploadId; final String uploadId;
try (AmazonS3Reference clientReference = s3BlobStore.clientReference()) { try (AmazonS3Reference clientReference = s3BlobStore.clientReference()) {
uploadId = SocketAccess.doPrivileged( uploadId = clientReference.client().createMultipartUpload(createMultipartUpload(purpose, operation, blobName)).uploadId();
() -> clientReference.client().createMultipartUpload(createMultipartUpload(purpose, operation, blobName)).uploadId()
);
cleanupOnFailureActions.add(() -> abortMultiPartUpload(purpose, uploadId, blobName)); cleanupOnFailureActions.add(() -> abortMultiPartUpload(purpose, uploadId, blobName));
} }
if (Strings.isEmpty(uploadId)) { if (Strings.isEmpty(uploadId)) {
@ -657,7 +642,7 @@ class S3BlobContainer extends AbstractBlobContainer {
S3BlobStore.configureRequestForMetrics(completeMultipartUploadRequestBuilder, blobStore, operation, purpose); S3BlobStore.configureRequestForMetrics(completeMultipartUploadRequestBuilder, blobStore, operation, purpose);
final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build(); final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build();
try (var clientReference = s3BlobStore.clientReference()) { try (var clientReference = s3BlobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> clientReference.client().completeMultipartUpload(completeMultipartUploadRequest)); clientReference.client().completeMultipartUpload(completeMultipartUploadRequest);
} }
cleanupOnFailureActions.clear(); cleanupOnFailureActions.clear();
} catch (final SdkException e) { } catch (final SdkException e) {
@ -691,10 +676,8 @@ class S3BlobContainer extends AbstractBlobContainer {
final UploadPartRequest uploadRequest = createPartUploadRequest(purpose, uploadId, partNum, blobName, partSize, lastPart); final UploadPartRequest uploadRequest = createPartUploadRequest(purpose, uploadId, partNum, blobName, partSize, lastPart);
try (var clientReference = s3BlobStore.clientReference()) { try (var clientReference = s3BlobStore.clientReference()) {
final UploadPartResponse uploadResponse = SocketAccess.doPrivileged( final UploadPartResponse uploadResponse = clientReference.client()
() -> clientReference.client().uploadPart(uploadRequest, RequestBody.fromInputStream(input, partSize)) .uploadPart(uploadRequest, RequestBody.fromInputStream(input, partSize));
);
return CompletedPart.builder().partNumber(partNum).eTag(uploadResponse.eTag()).build(); return CompletedPart.builder().partNumber(partNum).eTag(uploadResponse.eTag()).build();
} }
} }
@ -741,9 +724,7 @@ class S3BlobContainer extends AbstractBlobContainer {
final var uploadPartCopyRequest = uploadPartCopyRequestBuilder.build(); final var uploadPartCopyRequest = uploadPartCopyRequestBuilder.build();
try (AmazonS3Reference clientReference = blobStore.clientReference()) { try (AmazonS3Reference clientReference = blobStore.clientReference()) {
final var uploadPartCopyResponse = SocketAccess.doPrivileged( final var uploadPartCopyResponse = clientReference.client().uploadPartCopy(uploadPartCopyRequest);
() -> clientReference.client().uploadPartCopy(uploadPartCopyRequest)
);
return CompletedPart.builder().partNumber(partNum).eTag(uploadPartCopyResponse.copyPartResult().eTag()).build(); return CompletedPart.builder().partNumber(partNum).eTag(uploadPartCopyResponse.copyPartResult().eTag()).build();
} }
}) })
@ -934,7 +915,7 @@ class S3BlobContainer extends AbstractBlobContainer {
S3BlobStore.configureRequestForMetrics(listRequestBuilder, blobStore, Operation.LIST_OBJECTS, purpose); S3BlobStore.configureRequestForMetrics(listRequestBuilder, blobStore, Operation.LIST_OBJECTS, purpose);
final var listRequest = listRequestBuilder.build(); final var listRequest = listRequestBuilder.build();
try { try {
return SocketAccess.doPrivileged(() -> client.listMultipartUploads(listRequest)).uploads(); return client.listMultipartUploads(listRequest).uploads();
} catch (SdkServiceException e) { } catch (SdkServiceException e) {
if (e.statusCode() == 404) { if (e.statusCode() == 404) {
return List.of(); return List.of();
@ -947,7 +928,7 @@ class S3BlobContainer extends AbstractBlobContainer {
final var createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder().bucket(bucket).key(blobKey); final var createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder().bucket(bucket).key(blobKey);
S3BlobStore.configureRequestForMetrics(createMultipartUploadRequestBuilder, blobStore, Operation.PUT_MULTIPART_OBJECT, purpose); S3BlobStore.configureRequestForMetrics(createMultipartUploadRequestBuilder, blobStore, Operation.PUT_MULTIPART_OBJECT, purpose);
final var createMultipartUploadRequest = createMultipartUploadRequestBuilder.build(); final var createMultipartUploadRequest = createMultipartUploadRequestBuilder.build();
return SocketAccess.doPrivileged(() -> client.createMultipartUpload(createMultipartUploadRequest)).uploadId(); return client.createMultipartUpload(createMultipartUploadRequest).uploadId();
} }
private String uploadPartAndGetEtag(BytesReference updated, String uploadId) throws IOException { private String uploadPartAndGetEtag(BytesReference updated, String uploadId) throws IOException {
@ -958,12 +939,8 @@ class S3BlobContainer extends AbstractBlobContainer {
uploadPartRequestBuilder.partNumber(1); uploadPartRequestBuilder.partNumber(1);
uploadPartRequestBuilder.sdkPartType(SdkPartType.LAST); uploadPartRequestBuilder.sdkPartType(SdkPartType.LAST);
S3BlobStore.configureRequestForMetrics(uploadPartRequestBuilder, blobStore, Operation.PUT_MULTIPART_OBJECT, purpose); S3BlobStore.configureRequestForMetrics(uploadPartRequestBuilder, blobStore, Operation.PUT_MULTIPART_OBJECT, purpose);
return SocketAccess.doPrivilegedIOException( return client.uploadPart(uploadPartRequestBuilder.build(), RequestBody.fromInputStream(updated.streamInput(), updated.length()))
() -> client.uploadPart( .eTag();
uploadPartRequestBuilder.build(),
RequestBody.fromInputStream(updated.streamInput(), updated.length())
)
).eTag();
} }
private int getUploadIndex(String targetUploadId, List<MultipartUpload> multipartUploads) { private int getUploadIndex(String targetUploadId, List<MultipartUpload> multipartUploads) {
@ -1066,7 +1043,7 @@ class S3BlobContainer extends AbstractBlobContainer {
purpose purpose
); );
final var abortMultipartUploadRequest = abortMultipartUploadRequestBuilder.build(); final var abortMultipartUploadRequest = abortMultipartUploadRequestBuilder.build();
SocketAccess.doPrivilegedVoid(() -> client.abortMultipartUpload(abortMultipartUploadRequest)); client.abortMultipartUpload(abortMultipartUploadRequest);
} catch (SdkServiceException e) { } catch (SdkServiceException e) {
if (e.statusCode() != 404) { if (e.statusCode() != 404) {
throw e; throw e;
@ -1088,7 +1065,7 @@ class S3BlobContainer extends AbstractBlobContainer {
purpose purpose
); );
final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build(); final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build();
SocketAccess.doPrivilegedVoid(() -> client.completeMultipartUpload(completeMultipartUploadRequest)); client.completeMultipartUpload(completeMultipartUploadRequest);
} }
} }
@ -1138,7 +1115,7 @@ class S3BlobContainer extends AbstractBlobContainer {
final var getObjectRequest = getObjectRequestBuilder.build(); final var getObjectRequest = getObjectRequestBuilder.build();
try ( try (
var clientReference = blobStore.clientReference(); var clientReference = blobStore.clientReference();
var s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest)); var s3Object = clientReference.client().getObject(getObjectRequest);
) { ) {
return OptionalBytesReference.of(getRegisterUsingConsistentRead(s3Object, keyPath, key)); return OptionalBytesReference.of(getRegisterUsingConsistentRead(s3Object, keyPath, key));
} catch (Exception attemptException) { } catch (Exception attemptException) {
@ -1180,9 +1157,7 @@ class S3BlobContainer extends AbstractBlobContainer {
b -> b.putRawQueryParameter(S3BlobStore.CUSTOM_QUERY_PARAMETER_PURPOSE, OperationPurpose.SNAPSHOT_DATA.getKey()) b -> b.putRawQueryParameter(S3BlobStore.CUSTOM_QUERY_PARAMETER_PURPOSE, OperationPurpose.SNAPSHOT_DATA.getKey())
) )
.build(); .build();
final var multipartUploadListing = SocketAccess.doPrivileged( final var multipartUploadListing = clientReference.client().listMultipartUploads(listMultipartUploadsRequest);
() -> clientReference.client().listMultipartUploads(listMultipartUploadsRequest)
);
final var multipartUploads = multipartUploadListing.uploads(); final var multipartUploads = multipartUploadListing.uploads();
if (multipartUploads.isEmpty()) { if (multipartUploads.isEmpty()) {
logger.debug("found no multipart uploads to clean up"); logger.debug("found no multipart uploads to clean up");
@ -1237,7 +1212,7 @@ class S3BlobContainer extends AbstractBlobContainer {
while (abortMultipartUploadRequestIterator.hasNext()) { while (abortMultipartUploadRequestIterator.hasNext()) {
final var abortMultipartUploadRequest = abortMultipartUploadRequestIterator.next(); final var abortMultipartUploadRequest = abortMultipartUploadRequestIterator.next();
try { try {
SocketAccess.doPrivilegedVoid(() -> clientReference.client().abortMultipartUpload(abortMultipartUploadRequest)); clientReference.client().abortMultipartUpload(abortMultipartUploadRequest);
logger.info( logger.info(
"cleaned up dangling multipart upload [{}] of blob [{}][{}][{}]", "cleaned up dangling multipart upload [{}] of blob [{}][{}][{}]",
abortMultipartUploadRequest.uploadId(), abortMultipartUploadRequest.uploadId(),

View file

@ -344,9 +344,7 @@ class S3BlobStore implements BlobStore {
int retryCounter = 0; int retryCounter = 0;
while (true) { while (true) {
try (AmazonS3Reference clientReference = clientReference()) { try (AmazonS3Reference clientReference = clientReference()) {
final var response = SocketAccess.doPrivileged( final var response = clientReference.client().deleteObjects(bulkDelete(purpose, this, partition));
() -> clientReference.client().deleteObjects(bulkDelete(purpose, this, partition))
);
if (response.hasErrors()) { if (response.hasErrors()) {
final var exception = new ElasticsearchException(buildDeletionErrorMessage(response.errors())); final var exception = new ElasticsearchException(buildDeletionErrorMessage(response.errors()));
logger.warn(exception.getMessage(), exception); logger.warn(exception.getMessage(), exception);

View file

@ -13,7 +13,6 @@ import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import org.apache.lucene.util.SetOnce; import org.apache.lucene.util.SetOnce;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.cluster.metadata.RepositoryMetadata; import org.elasticsearch.cluster.metadata.RepositoryMetadata;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
@ -32,9 +31,6 @@ import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.NamedXContentRegistry;
import java.io.IOException; import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@ -48,20 +44,6 @@ public class S3RepositoryPlugin extends Plugin implements RepositoryPlugin, Relo
private static final Logger logger = LogManager.getLogger(S3RepositoryPlugin.class); private static final Logger logger = LogManager.getLogger(S3RepositoryPlugin.class);
static {
SpecialPermission.check();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
try {
// Eagerly load the RegionFromEndpointGuesser map from the resource file
MethodHandles.lookup().ensureInitialized(RegionFromEndpointGuesser.class);
} catch (IllegalAccessException unexpected) {
throw new AssertionError(unexpected);
}
return null;
});
}
private final SetOnce<S3Service> service = new SetOnce<>(); private final SetOnce<S3Service> service = new SetOnce<>();
private final Settings settings; private final Settings settings;
@ -97,14 +79,12 @@ public class S3RepositoryPlugin extends Plugin implements RepositoryPlugin, Relo
} }
private static Region getDefaultRegion() { private static Region getDefaultRegion() {
return AccessController.doPrivileged((PrivilegedAction<Region>) () -> { try {
try { return DefaultAwsRegionProviderChain.builder().build().getRegion();
return DefaultAwsRegionProviderChain.builder().build().getRegion(); } catch (Exception e) {
} catch (Exception e) { logger.info("failed to obtain region from default provider chain", e);
logger.info("failed to obtain region from default provider chain", e); return null;
return null; }
}
});
} }
@Override @Override

View file

@ -99,7 +99,7 @@ class S3RetryingInputStream extends InputStream {
} }
this.currentStreamFirstOffset = Math.addExact(start, currentOffset); this.currentStreamFirstOffset = Math.addExact(start, currentOffset);
final var getObjectRequest = getObjectRequestBuilder.build(); final var getObjectRequest = getObjectRequestBuilder.build();
final var getObjectResponse = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest)); final var getObjectResponse = clientReference.client().getObject(getObjectRequest);
this.currentStreamLastOffset = Math.addExact(currentStreamFirstOffset, getStreamLength(getObjectResponse.response())); this.currentStreamLastOffset = Math.addExact(currentStreamFirstOffset, getStreamLength(getObjectResponse.response()));
this.currentStream = getObjectResponse; this.currentStream = getObjectResponse;
return; return;

View file

@ -62,7 +62,6 @@ import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.security.PrivilegedAction;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.Map; import java.util.Map;
@ -229,7 +228,7 @@ class S3Service extends AbstractLifecycleComponent {
// proxy for testing // proxy for testing
S3Client buildClient(final S3ClientSettings clientSettings, SdkHttpClient httpClient) { S3Client buildClient(final S3ClientSettings clientSettings, SdkHttpClient httpClient) {
final S3ClientBuilder s3clientBuilder = buildClientBuilder(clientSettings, httpClient); final S3ClientBuilder s3clientBuilder = buildClientBuilder(clientSettings, httpClient);
return SocketAccess.doPrivileged(s3clientBuilder::build); return s3clientBuilder.build();
} }
protected S3ClientBuilder buildClientBuilder(S3ClientSettings clientSettings, SdkHttpClient httpClient) { protected S3ClientBuilder buildClientBuilder(S3ClientSettings clientSettings, SdkHttpClient httpClient) {
@ -422,20 +421,18 @@ class S3Service extends AbstractLifecycleComponent {
if (credentials == null) { if (credentials == null) {
if (webIdentityTokenCredentialsProvider.isActive()) { if (webIdentityTokenCredentialsProvider.isActive()) {
logger.debug("Using a custom provider chain of Web Identity Token and instance profile credentials"); logger.debug("Using a custom provider chain of Web Identity Token and instance profile credentials");
return new PrivilegedAwsCredentialsProvider( // Wrap the credential providers in ErrorLoggingCredentialsProvider so that we get log info if/when the STS
// Wrap the credential providers in ErrorLoggingCredentialsProvider so that we get log info if/when the STS // (in CustomWebIdentityTokenCredentialsProvider) is unavailable to the ES server, before falling back to a standard
// (in CustomWebIdentityTokenCredentialsProvider) is unavailable to the ES server, before falling back to a standard // credential provider.
// credential provider. return AwsCredentialsProviderChain.builder()
AwsCredentialsProviderChain.builder() // If credentials are refreshed, we want to look around for different forms of credentials again.
// If credentials are refreshed, we want to look around for different forms of credentials again. .reuseLastProviderEnabled(false)
.reuseLastProviderEnabled(false) .addCredentialsProvider(new ErrorLoggingCredentialsProvider(webIdentityTokenCredentialsProvider, LOGGER))
.addCredentialsProvider(new ErrorLoggingCredentialsProvider(webIdentityTokenCredentialsProvider, LOGGER)) .addCredentialsProvider(new ErrorLoggingCredentialsProvider(DefaultCredentialsProvider.create(), LOGGER))
.addCredentialsProvider(new ErrorLoggingCredentialsProvider(DefaultCredentialsProvider.create(), LOGGER)) .build();
.build()
);
} else { } else {
logger.debug("Using DefaultCredentialsProvider for credentials"); logger.debug("Using DefaultCredentialsProvider for credentials");
return new PrivilegedAwsCredentialsProvider(DefaultCredentialsProvider.create()); return DefaultCredentialsProvider.create();
} }
} else { } else {
logger.debug("Using basic key/secret credentials"); logger.debug("Using basic key/secret credentials");
@ -471,46 +468,6 @@ class S3Service extends AbstractLifecycleComponent {
webIdentityTokenCredentialsProvider.close(); webIdentityTokenCredentialsProvider.close();
} }
/**
* Wraps calls with {@link SocketAccess#doPrivileged(PrivilegedAction)} where needed.
*/
static class PrivilegedAwsCredentialsProvider implements AwsCredentialsProvider {
private final AwsCredentialsProvider delegate;
private PrivilegedAwsCredentialsProvider(AwsCredentialsProvider delegate) {
this.delegate = delegate;
}
AwsCredentialsProvider getCredentialsProvider() {
return delegate;
}
@Override
public AwsCredentials resolveCredentials() {
return delegate.resolveCredentials();
}
@Override
public Class<AwsCredentialsIdentity> identityType() {
return delegate.identityType();
}
@Override
public CompletableFuture<AwsCredentialsIdentity> resolveIdentity(ResolveIdentityRequest request) {
return SocketAccess.doPrivileged(() -> delegate.resolveIdentity(request));
}
@Override
public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity(Consumer<ResolveIdentityRequest.Builder> consumer) {
return SocketAccess.doPrivileged(() -> delegate.resolveIdentity(consumer));
}
@Override
public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity() {
return SocketAccess.doPrivileged(delegate::resolveIdentity);
}
}
/** /**
* Customizes {@link StsWebIdentityTokenFileCredentialsProvider}. * Customizes {@link StsWebIdentityTokenFileCredentialsProvider}.
* *
@ -634,7 +591,7 @@ class S3Service extends AbstractLifecycleComponent {
public void onFileChanged(Path file) { public void onFileChanged(Path file) {
if (file.equals(webIdentityTokenFileSymlink)) { if (file.equals(webIdentityTokenFileSymlink)) {
LOGGER.debug("WS web identity token file [{}] changed, updating credentials", file); LOGGER.debug("WS web identity token file [{}] changed, updating credentials", file);
SocketAccess.doPrivilegedVoid(credentialsProvider::resolveCredentials); credentialsProvider.resolveCredentials();
} }
} }
}); });
@ -676,19 +633,19 @@ class S3Service extends AbstractLifecycleComponent {
@Override @Override
public CompletableFuture<AwsCredentialsIdentity> resolveIdentity(ResolveIdentityRequest request) { public CompletableFuture<AwsCredentialsIdentity> resolveIdentity(ResolveIdentityRequest request) {
Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set"); Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set");
return SocketAccess.doPrivileged(() -> credentialsProvider.resolveIdentity(request)); return credentialsProvider.resolveIdentity(request);
} }
@Override @Override
public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity(Consumer<ResolveIdentityRequest.Builder> consumer) { public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity(Consumer<ResolveIdentityRequest.Builder> consumer) {
Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set"); Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set");
return SocketAccess.doPrivileged(() -> credentialsProvider.resolveIdentity(consumer)); return credentialsProvider.resolveIdentity(consumer);
} }
@Override @Override
public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity() { public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity() {
Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set"); Objects.requireNonNull(credentialsProvider, "credentialsProvider is not set");
return SocketAccess.doPrivileged(credentialsProvider::resolveIdentity); return credentialsProvider.resolveIdentity();
} }
} }
@ -737,17 +694,17 @@ class S3Service extends AbstractLifecycleComponent {
@Override @Override
public CompletableFuture<AwsCredentialsIdentity> resolveIdentity(ResolveIdentityRequest request) { public CompletableFuture<AwsCredentialsIdentity> resolveIdentity(ResolveIdentityRequest request) {
return SocketAccess.doPrivileged(() -> delegate.resolveIdentity(request).handle(this::resultHandler)); return delegate.resolveIdentity(request).handle(this::resultHandler);
} }
@Override @Override
public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity(Consumer<ResolveIdentityRequest.Builder> consumer) { public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity(Consumer<ResolveIdentityRequest.Builder> consumer) {
return SocketAccess.doPrivileged(() -> delegate.resolveIdentity(consumer).handle(this::resultHandler)); return delegate.resolveIdentity(consumer).handle(this::resultHandler);
} }
@Override @Override
public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity() { public CompletableFuture<? extends AwsCredentialsIdentity> resolveIdentity() {
return SocketAccess.doPrivileged(() -> delegate.resolveIdentity().handle(this::resultHandler)); return delegate.resolveIdentity().handle(this::resultHandler);
} }
@Override @Override

View file

@ -1,52 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.repositories.s3;
import org.elasticsearch.SpecialPermission;
import java.io.IOException;
import java.net.SocketPermission;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
/**
* This plugin uses aws libraries to connect to S3 repositories. For these remote calls the plugin needs
* {@link SocketPermission} 'connect' to establish connections. This class wraps the operations requiring access in
* {@link AccessController#doPrivileged(PrivilegedAction)} blocks.
*/
final class SocketAccess {
private SocketAccess() {}
public static <T> T doPrivileged(PrivilegedAction<T> operation) {
SpecialPermission.check();
return AccessController.doPrivileged(operation);
}
public static <T> T doPrivilegedIOException(PrivilegedExceptionAction<T> operation) throws IOException {
SpecialPermission.check();
try {
return AccessController.doPrivileged(operation);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
public static void doPrivilegedVoid(Runnable action) {
SpecialPermission.check();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
action.run();
return null;
});
}
}

View file

@ -61,9 +61,7 @@ public class AwsS3ServiceImplTests extends ESTestCase {
clientSettings, clientSettings,
webIdentityTokenCredentialsProvider webIdentityTokenCredentialsProvider
); );
assertThat(credentialsProvider, instanceOf(S3Service.PrivilegedAwsCredentialsProvider.class)); assertThat(credentialsProvider, instanceOf(DefaultCredentialsProvider.class));
var privilegedAWSCredentialsProvider = (S3Service.PrivilegedAwsCredentialsProvider) credentialsProvider;
assertThat(privilegedAWSCredentialsProvider.getCredentialsProvider(), instanceOf(DefaultCredentialsProvider.class));
} }
public void testSupportsWebIdentityTokenCredentials() { public void testSupportsWebIdentityTokenCredentials() {
@ -80,10 +78,8 @@ public class AwsS3ServiceImplTests extends ESTestCase {
S3ClientSettings.getClientSettings(Settings.EMPTY, randomAlphaOfLength(8).toLowerCase(Locale.ROOT)), S3ClientSettings.getClientSettings(Settings.EMPTY, randomAlphaOfLength(8).toLowerCase(Locale.ROOT)),
webIdentityTokenCredentialsProvider webIdentityTokenCredentialsProvider
); );
assertThat(credentialsProvider, instanceOf(S3Service.PrivilegedAwsCredentialsProvider.class)); assertThat(credentialsProvider, instanceOf(AwsCredentialsProviderChain.class));
var privilegedAWSCredentialsProvider = (S3Service.PrivilegedAwsCredentialsProvider) credentialsProvider; AwsCredentials resolvedCredentials = credentialsProvider.resolveCredentials();
assertThat(privilegedAWSCredentialsProvider.getCredentialsProvider(), instanceOf(AwsCredentialsProviderChain.class));
AwsCredentials resolvedCredentials = privilegedAWSCredentialsProvider.resolveCredentials();
assertEquals("sts_access_key_id", resolvedCredentials.accessKeyId()); assertEquals("sts_access_key_id", resolvedCredentials.accessKeyId());
assertEquals("sts_secret_key", resolvedCredentials.secretAccessKey()); assertEquals("sts_secret_key", resolvedCredentials.secretAccessKey());
} }
@ -122,9 +118,7 @@ public class AwsS3ServiceImplTests extends ESTestCase {
defaultClientSettings, defaultClientSettings,
webIdentityTokenCredentialsProvider webIdentityTokenCredentialsProvider
); );
assertThat(defaultCredentialsProvider, instanceOf(S3Service.PrivilegedAwsCredentialsProvider.class)); assertThat(defaultCredentialsProvider, instanceOf(DefaultCredentialsProvider.class));
var privilegedAWSCredentialsProvider = (S3Service.PrivilegedAwsCredentialsProvider) defaultCredentialsProvider;
assertThat(privilegedAWSCredentialsProvider.getCredentialsProvider(), instanceOf(DefaultCredentialsProvider.class));
} }
public void testBasicAccessKeyAndSecretKeyCredentials() { public void testBasicAccessKeyAndSecretKeyCredentials() {

View file

@ -28,9 +28,6 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.URL; import java.net.URL;
import java.nio.file.NoSuchFileException; import java.nio.file.NoSuchFileException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
@ -158,11 +155,7 @@ public class URLBlobContainer extends AbstractBlobContainer {
@SuppressForbidden(reason = "We call connect in doPrivileged and provide SocketPermission") @SuppressForbidden(reason = "We call connect in doPrivileged and provide SocketPermission")
private static InputStream getInputStream(URL url) throws IOException { private static InputStream getInputStream(URL url) throws IOException {
try { return url.openStream();
return AccessController.doPrivileged((PrivilegedExceptionAction<InputStream>) url::openStream);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
} }
@Override @Override

View file

@ -19,9 +19,6 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.nio.file.NoSuchFileException; import java.nio.file.NoSuchFileException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -221,44 +218,36 @@ class RetryingHttpInputStream extends InputStream {
private HttpResponseInputStream openInputStream() throws IOException { private HttpResponseInputStream openInputStream() throws IOException {
try { try {
return AccessController.doPrivileged((PrivilegedExceptionAction<HttpResponseInputStream>) () -> { final Map<String, String> headers = Maps.newMapWithExpectedSize(1);
final Map<String, String> headers = Maps.newMapWithExpectedSize(1);
if (isRangeRead()) { if (isRangeRead()) {
headers.put("Range", getBytesRange(Math.addExact(start, totalBytesRead), end)); headers.put("Range", getBytesRange(Math.addExact(start, totalBytesRead), end));
}
try {
final URLHttpClient.HttpResponse response = httpClient.get(blobURI, headers);
final int statusCode = response.getStatusCode();
if (statusCode != RestStatus.OK.getStatus() && statusCode != RestStatus.PARTIAL_CONTENT.getStatus()) {
String body = response.getBodyAsString(MAX_ERROR_MESSAGE_BODY_SIZE);
IOUtils.closeWhileHandlingException(response);
throw new IOException(
getErrorMessage(
"The server returned an invalid response:" + " Status code: [" + statusCode + "] - Body: " + body
)
);
}
currentStreamLastOffset = Math.addExact(Math.addExact(start, totalBytesRead), getStreamLength(response));
return response.getInputStream();
} catch (URLHttpClientException e) {
if (e.getStatusCode() == RestStatus.NOT_FOUND.getStatus()) {
throw new NoSuchFileException("blob object [" + blobName + "] not found");
} else {
throw e;
}
}
});
} catch (PrivilegedActionException e) {
final Throwable cause = e.getCause();
if (cause instanceof IOException ioException) {
throw ioException;
} }
throw new IOException(getErrorMessage(), e);
try {
final URLHttpClient.HttpResponse response = httpClient.get(blobURI, headers);
final int statusCode = response.getStatusCode();
if (statusCode != RestStatus.OK.getStatus() && statusCode != RestStatus.PARTIAL_CONTENT.getStatus()) {
String body = response.getBodyAsString(MAX_ERROR_MESSAGE_BODY_SIZE);
IOUtils.closeWhileHandlingException(response);
throw new IOException(
getErrorMessage("The server returned an invalid response:" + " Status code: [" + statusCode + "] - Body: " + body)
);
}
currentStreamLastOffset = Math.addExact(Math.addExact(start, totalBytesRead), getStreamLength(response));
return response.getInputStream();
} catch (URLHttpClientException e) {
if (e.getStatusCode() == RestStatus.NOT_FOUND.getStatus()) {
throw new NoSuchFileException("blob object [" + blobName + "] not found");
} else {
throw e;
}
}
} catch (IOException e) {
throw e;
} catch (Exception e) { } catch (Exception e) {
throw new IOException(getErrorMessage(), e); throw new IOException(getErrorMessage(), e);
} }

View file

@ -28,8 +28,6 @@ import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays; import java.util.Arrays;
import java.util.Map; import java.util.Map;
@ -246,9 +244,7 @@ public class URLHttpClientTests extends ESTestCase {
} }
private URLHttpClient.HttpResponse executeRequest(String endpoint) throws Exception { private URLHttpClient.HttpResponse executeRequest(String endpoint) throws Exception {
return AccessController.doPrivileged((PrivilegedExceptionAction<URLHttpClient.HttpResponse>) () -> { return httpClient.get(getURIForEndpoint(endpoint), Map.of());
return httpClient.get(getURIForEndpoint(endpoint), Map.of());
});
} }
private URI getURIForEndpoint(String endpoint) throws Exception { private URI getURIForEndpoint(String endpoint) throws Exception {

View file

@ -18,8 +18,6 @@ import org.elasticsearch.grok.GrokBuiltinPatterns;
import org.elasticsearch.grok.MatcherWatchdog; import org.elasticsearch.grok.MatcherWatchdog;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -95,25 +93,21 @@ public interface NamedGroupExtractor {
* Build the grok pattern in a PrivilegedAction so it can load * Build the grok pattern in a PrivilegedAction so it can load
* things from the classpath. * things from the classpath.
*/ */
Grok grok = AccessController.doPrivileged(new PrivilegedAction<Grok>() { Grok grok;
@Override try {
public Grok run() { // Try to collect warnings up front and refuse to compile the expression if there are any
try { List<String> warnings = new ArrayList<>();
// Try to collect warnings up front and refuse to compile the expression if there are any new Grok(GrokBuiltinPatterns.legacyPatterns(), pattern, watchdog, warnings::add).match("__nomatch__");
List<String> warnings = new ArrayList<>(); if (false == warnings.isEmpty()) {
new Grok(GrokBuiltinPatterns.legacyPatterns(), pattern, watchdog, warnings::add).match("__nomatch__"); throw new IllegalArgumentException("emitted warnings: " + warnings);
if (false == warnings.isEmpty()) {
throw new IllegalArgumentException("emitted warnings: " + warnings);
}
return new Grok(GrokBuiltinPatterns.legacyPatterns(), pattern, watchdog, w -> {
throw new IllegalArgumentException("grok [" + pattern + "] emitted a warning: " + w);
});
} catch (RuntimeException e) {
throw new IllegalArgumentException("error compiling grok pattern [" + pattern + "]: " + e.getMessage(), e);
}
} }
});
grok = new Grok(GrokBuiltinPatterns.legacyPatterns(), pattern, watchdog, w -> {
throw new IllegalArgumentException("grok [" + pattern + "] emitted a warning: " + w);
});
} catch (RuntimeException e) {
throw new IllegalArgumentException("error compiling grok pattern [" + pattern + "]: " + e.getMessage(), e);
}
return new NamedGroupExtractor() { return new NamedGroupExtractor() {
@Override @Override
public Map<String, ?> extract(String in) { public Map<String, ?> extract(String in) {