Introduce RefCounted#mustIncRef (#102515)

In several places we acquire a ref to a resource that we are certain is
not closed, so this commit adds a utility for asserting this to be the
case. This also helps a little with mocks since boolean methods like
`tryIncRef()` return `false` on mock objects by default, but void
methods like `mustIncRef()` default to being a no-op.
This commit is contained in:
David Turner 2023-11-23 21:40:43 +00:00 committed by GitHub
parent 0ecb2af13d
commit b2127ec2f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 51 additions and 62 deletions

View file

@ -19,6 +19,7 @@ import java.util.Objects;
public abstract class AbstractRefCounted implements RefCounted {
public static final String ALREADY_CLOSED_MESSAGE = "already closed, can't increment ref count";
public static final String INVALID_DECREF_MESSAGE = "invalid decRef call: already closed";
private static final VarHandle VH_REFCOUNT_FIELD;
@ -63,7 +64,7 @@ public abstract class AbstractRefCounted implements RefCounted {
public final boolean decRef() {
touch();
int i = (int) VH_REFCOUNT_FIELD.getAndAdd(this, -1);
assert i > 0 : "invalid decRef call: already closed";
assert i > 0 : INVALID_DECREF_MESSAGE;
if (i == 1) {
try {
closeInternal();

View file

@ -62,4 +62,16 @@ public interface RefCounted {
* @return whether there are currently any active references to this object.
*/
boolean hasReferences();
/**
* Similar to {@link #incRef()} except that it also asserts that it managed to acquire the ref, for use in situations where it is a bug
* if all refs have been released.
*/
default void mustIncRef() {
if (tryIncRef()) {
return;
}
assert false : AbstractRefCounted.ALREADY_CLOSED_MESSAGE;
incRef(); // throws an ISE
}
}

View file

@ -180,15 +180,11 @@ public class GeoIpDownloaderTests extends ESTestCase {
public void testIndexChunksNoData() throws IOException {
client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
var flushResponse = mock(FlushResponse.class);
when(flushResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(flushResponse);
flushResponseActionListener.onResponse(mock(FlushResponse.class));
});
client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
var refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(refreshResponse);
flushResponseActionListener.onResponse(mock(RefreshResponse.class));
});
InputStream empty = new ByteArrayInputStream(new byte[0]);
@ -198,15 +194,11 @@ public class GeoIpDownloaderTests extends ESTestCase {
public void testIndexChunksMd5Mismatch() {
client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
var flushResponse = mock(FlushResponse.class);
when(flushResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(flushResponse);
flushResponseActionListener.onResponse(mock(FlushResponse.class));
});
client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
var refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(refreshResponse);
flushResponseActionListener.onResponse(mock(RefreshResponse.class));
});
IOException exception = expectThrows(
@ -238,21 +230,15 @@ public class GeoIpDownloaderTests extends ESTestCase {
assertEquals("test", source.get("name"));
assertArrayEquals(chunksData[chunk], (byte[]) source.get("data"));
assertEquals(chunk + 15, source.get("chunk"));
var indexResponse = mock(IndexResponse.class);
when(indexResponse.hasReferences()).thenReturn(true);
listener.onResponse(indexResponse);
listener.onResponse(mock(IndexResponse.class));
});
client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
var flushResponse = mock(FlushResponse.class);
when(flushResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(flushResponse);
flushResponseActionListener.onResponse(mock(FlushResponse.class));
});
client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
var refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
flushResponseActionListener.onResponse(refreshResponse);
flushResponseActionListener.onResponse(mock(RefreshResponse.class));
});
InputStream big = new ByteArrayInputStream(bigArray);

View file

@ -135,7 +135,7 @@ class S3Service implements Closeable {
return existing;
}
final AmazonS3Reference clientReference = new AmazonS3Reference(buildClient(clientSettings));
clientReference.incRef();
clientReference.mustIncRef();
clientsCache = Maps.copyMapWithAddedEntry(clientsCache, clientSettings, clientReference);
return clientReference;
}

View file

@ -63,7 +63,6 @@ import org.elasticsearch.core.Releasables;
public final class RefCountingRunnable implements Releasable {
private static final Logger logger = LogManager.getLogger(RefCountingRunnable.class);
static final String ALREADY_CLOSED_MESSAGE = "already closed, cannot acquire or release any further refs";
private final RefCounted refCounted;
@ -86,15 +85,12 @@ public final class RefCountingRunnable implements Releasable {
* will be ignored otherwise. This deviates from the contract of {@link java.io.Closeable}.
*/
public Releasable acquire() {
if (refCounted.tryIncRef()) {
// All refs are considered equal so there's no real need to allocate a new object here, although note that this deviates
// (subtly) from the docs for Closeable#close() which indicate that it should be idempotent. But only if assertions are
// disabled, and if assertions are enabled then we are asserting that we never double-close these things anyway.
refCounted.mustIncRef();
// All refs are considered equal so there's no real need to allocate a new object here, although note that this deviates (subtly)
// from the docs for Closeable#close() which indicate that it should be idempotent. But only if assertions are disabled, and if
// assertions are enabled then we are asserting that we never double-close these things anyway.
return Releasables.assertOnce(this);
}
assert false : ALREADY_CLOSED_MESSAGE;
throw new IllegalStateException(ALREADY_CLOSED_MESSAGE);
}
/**
* Acquire a reference to this object and return a listener which releases it when notified. The delegate {@link Runnable} is called

View file

@ -228,7 +228,7 @@ public abstract class TransportBroadcastByNodeAction<
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
// workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
request.incRef();
request.mustIncRef();
executor.execute(ActionRunnable.wrapReleasing(listener, request::decRef, l -> doExecuteForked(task, request, listener)));
}
@ -474,7 +474,7 @@ public abstract class TransportBroadcastByNodeAction<
}
NodeRequest(Request indicesLevelRequest, List<ShardRouting> shards, String nodeId) {
indicesLevelRequest.incRef();
indicesLevelRequest.mustIncRef();
this.indicesLevelRequest = indicesLevelRequest;
this.shards = shards;
this.nodeId = nodeId;

View file

@ -169,7 +169,7 @@ public abstract class TransportMasterNodeAction<Request extends MasterNodeReques
if (task != null) {
request.setParentTask(clusterService.localNode().getId(), task.getId());
}
request.incRef();
request.mustIncRef();
new AsyncSingleAction(task, request, ActionListener.runBefore(listener, request::decRef)).doStart(state);
}

View file

@ -290,7 +290,7 @@ public abstract class TransportTasksAction<
protected NodeTaskRequest(TasksRequest tasksRequest) {
super();
tasksRequest.incRef();
tasksRequest.mustIncRef();
this.tasksRequest = tasksRequest;
}

View file

@ -1612,9 +1612,8 @@ public abstract class AbstractClient implements Client {
@Override
public final void onResponse(R result) {
assert result.hasReferences();
if (set(result)) {
result.incRef();
result.mustIncRef();
}
}

View file

@ -363,8 +363,7 @@ public class JoinValidationService {
);
return;
}
assert bytes.hasReferences() : "already closed";
bytes.incRef();
bytes.mustIncRef();
transportService.sendRequest(
connection,
JOIN_VALIDATE_ACTION_NAME,

View file

@ -192,7 +192,7 @@ public abstract class CancellableSingleObjectCache<Input, Key, Value> {
CachedItem(Key key) {
this.key = key;
incRef(); // start with a refcount of 2 so we're not closed while adding the first listener
mustIncRef(); // start with a refcount of 2 so we're not closed while adding the first listener
this.future.addListener(new ActionListener<>() {
@Override
public void onResponse(Value value) {

View file

@ -88,7 +88,7 @@ public class ThrottledIterator<T> implements Releasable {
}
}
try (var itemRefs = new ItemRefCounted()) {
itemRefs.incRef();
itemRefs.mustIncRef();
itemConsumer.accept(Releasables.releaseOnce(itemRefs::decRef), item);
} catch (Exception e) {
logger.error(Strings.format("exception when processing [%s] with [%s]", item, itemConsumer), e);
@ -108,7 +108,7 @@ public class ThrottledIterator<T> implements Releasable {
private boolean isRecursive = true;
ItemRefCounted() {
refs.incRef();
refs.mustIncRef();
}
@Override

View file

@ -223,7 +223,7 @@ public class ClusterConnectionManager implements ConnectionManager {
IOUtils.closeWhileHandlingException(conn);
} else {
logger.debug("connected to node [{}]", node);
managerRefs.incRef();
managerRefs.mustIncRef();
try {
connectionListener.onNodeConnected(node, conn);
} finally {

View file

@ -293,7 +293,7 @@ public class InboundHandler {
private <T extends TransportRequest> void handleRequestForking(T request, RequestHandlerRegistry<T> reg, TransportChannel channel) {
boolean success = false;
request.incRef();
request.mustIncRef();
try {
reg.getExecutor().execute(threadPool.getThreadContext().preserveContextWithTracing(new AbstractRunnable() {
@Override
@ -381,7 +381,7 @@ public class InboundHandler {
// no need to provide a buffer release here, we never escape the buffer when handling directly
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
} else {
inboundMessage.incRef();
inboundMessage.mustIncRef();
// release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
executor.execute(new ForkingResponseHandlerRunnable(handler, null, threadPool) {

View file

@ -65,7 +65,7 @@ public final class TransportActionProxy {
@Override
public void handleResponse(TransportResponse response) {
try {
response.incRef();
response.mustIncRef();
channel.sendResponse(response);
} catch (IOException e) {
throw new UncheckedIOException(e);

View file

@ -1013,7 +1013,7 @@ public class TransportService extends AbstractLifecycleComponent
}
} else {
boolean success = false;
request.incRef();
request.mustIncRef();
try {
executor.execute(threadPool.getThreadContext().preserveContextWithTracing(new AbstractRunnable() {
@Override
@ -1479,7 +1479,7 @@ public class TransportService extends AbstractLifecycleComponent
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
processResponse(handler, response);
} else {
response.incRef();
response.mustIncRef();
executor.execute(new ForkingResponseHandlerRunnable(handler, null, threadPool) {
@Override
protected void doRun() {

View file

@ -11,6 +11,7 @@ package org.elasticsearch.action.support;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.ReachabilityChecker;
@ -174,10 +175,10 @@ public class RefCountingListenerTests extends ESTestCase {
final String expectedMessage;
if (randomBoolean()) {
throwingRunnable = refs::acquire;
expectedMessage = RefCountingRunnable.ALREADY_CLOSED_MESSAGE;
expectedMessage = AbstractRefCounted.ALREADY_CLOSED_MESSAGE;
} else {
throwingRunnable = refs::close;
expectedMessage = "invalid decRef call: already closed";
expectedMessage = AbstractRefCounted.INVALID_DECREF_MESSAGE;
}
assertEquals(expectedMessage, expectThrows(AssertionError.class, throwingRunnable).getMessage());

View file

@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.test.ESTestCase;
@ -166,10 +167,10 @@ public class RefCountingRunnableTests extends ESTestCase {
final String expectedMessage;
if (randomBoolean()) {
throwingRunnable = randomBoolean() ? refs::acquire : refs::acquireListener;
expectedMessage = RefCountingRunnable.ALREADY_CLOSED_MESSAGE;
expectedMessage = AbstractRefCounted.ALREADY_CLOSED_MESSAGE;
} else {
throwingRunnable = refs::close;
expectedMessage = "invalid decRef call: already closed";
expectedMessage = AbstractRefCounted.INVALID_DECREF_MESSAGE;
}
assertEquals(expectedMessage, expectThrows(AssertionError.class, throwingRunnable).getMessage());

View file

@ -150,7 +150,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
assert destinationTransport.getLocalNode().equals(getLocalNode()) == false
: "non-local message from " + getLocalNode() + " to itself";
request.incRef();
request.mustIncRef();
destinationTransport.execute(new RebootSensitiveRunnable() {
@Override

View file

@ -140,9 +140,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
doAnswer(invocationOnMock -> {
ActionListener<ClearScrollResponse> listener = (ActionListener<ClearScrollResponse>) invocationOnMock.getArguments()[2];
wasScrollCleared = true;
var clearScrollResponse = mock(ClearScrollResponse.class);
when(clearScrollResponse.hasReferences()).thenReturn(true);
listener.onResponse(clearScrollResponse);
listener.onResponse(mock(ClearScrollResponse.class));
return null;
}).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any());
}
@ -173,7 +171,6 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
protected SearchResponse createSearchResponseWithHits(String... hits) {
SearchHits searchHits = createHits(hits);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.hasReferences()).thenReturn(true);
when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
when(searchResponse.getHits()).thenReturn(searchHits);
return searchResponse;

View file

@ -543,7 +543,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
AbstractRunnable getReceiveRunnable(T request, TransportChannel channel, Task task) {
final Runnable releaseRequest = new RunOnce(request::decRef);
request.incRef();
request.mustIncRef();
return new AbstractRunnable() {
@Override
public boolean isForceExecution() {

View file

@ -163,7 +163,6 @@ public class WatcherServiceTests extends ESTestCase {
// response setup, successful refresh response
RefreshResponse refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.hasReferences()).thenReturn(true);
when(refreshResponse.getSuccessfulShards()).thenReturn(
clusterState.getMetadata().getIndices().get(Watch.INDEX).getNumberOfShards()
);

View file

@ -210,7 +210,6 @@ public class TriggeredWatchStoreTests extends ESTestCase {
SearchResponse searchResponse1 = mock(SearchResponse.class);
when(searchResponse1.getSuccessfulShards()).thenReturn(1);
when(searchResponse1.getTotalShards()).thenReturn(1);
when(searchResponse1.hasReferences()).thenReturn(true);
BytesArray source = new BytesArray("{}");
SearchHit hit = new SearchHit(0, "first_foo");
hit.version(1L);
@ -513,7 +512,6 @@ public class TriggeredWatchStoreTests extends ESTestCase {
RefreshResponse refreshResponse = mock(RefreshResponse.class);
when(refreshResponse.getTotalShards()).thenReturn(total);
when(refreshResponse.getSuccessfulShards()).thenReturn(successful);
when(refreshResponse.hasReferences()).thenReturn(true);
return refreshResponse;
}
}