From 7f3d0b649adcd915c01da3baa73935e0e880fbb5 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 2 Feb 2026 11:03:14 +0000 Subject: [PATCH] Normalize ByteBuf retention implementations - Updated `NettyByteBuf` to maintain its own references - Updated `CompositeByteBuf` to retain and release the nested ByteBufs - Updated `ByteBufTest` and `CompositeByteBuf` to improve test coverage JAVA-6009 --- .../internal/connection/CompositeByteBuf.java | 78 ++- .../connection/netty/NettyByteBuf.java | 14 +- .../connection/ByteBufSpecification.groovy | 255 ---------- .../internal/connection/ByteBufTest.java | 460 +++++++++++++++++- .../connection/CompositeByteBufTest.java | 357 ++++++++++---- .../connection/TestBufferProviders.java | 407 ++++++++++++++++ 6 files changed, 1198 insertions(+), 373 deletions(-) delete mode 100644 driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy create mode 100644 driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java diff --git a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java index 67121ecbe3c..e3eb299a40e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java @@ -24,17 +24,55 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; import static java.lang.String.format; import static org.bson.assertions.Assertions.isTrueArgument; import static org.bson.assertions.Assertions.notNull; -class CompositeByteBuf implements ByteBuf { +/** + * A composite {@link ByteBuf} that provides a unified view over a list of component buffers. + * + *

ByteBuf Ownership and Reference Counting

+ *

This class manages the lifecycle of its component buffers with the following rules:

+ * + * + *

Important: The composite's reference count is independent from its components' + * reference counts, but they are kept in sync via {@link #retain()} and {@link #release()} methods.

+ * + *

This class is not part of the public API and may be removed or changed at any time

+ */ +final class CompositeByteBuf implements ByteBuf { private final List components; private final AtomicInteger referenceCount = new AtomicInteger(1); private int position; private int limit; + /** + * Creates a composite buffer from the given list of buffers. + * + *

ByteBuf Ownership: This constructor does NOT take ownership of the input buffers. + * It calls {@link ByteBuf#asReadOnly()} on each buffer, which creates a shallow read-only view that + * shares the same underlying data and reference count as the original. The caller retains ownership + * of the original buffers and is responsible for their lifecycle.

+ * + *

The composite starts with a reference count of 1. When {@link #release()} is called and the + * reference count reaches 0, it does NOT automatically release the original buffers - the caller + * must handle that separately.

+ * + * @param buffers the list of buffers to compose, must not be null or empty + */ CompositeByteBuf(final List buffers) { notNull("buffers", buffers); isTrueArgument("buffer list not empty", !buffers.isEmpty()); @@ -50,7 +88,9 @@ class CompositeByteBuf implements ByteBuf { } private CompositeByteBuf(final CompositeByteBuf from) { - components = from.components; + components = from.components.stream().map(c -> + new Component(c.buffer.retain(), c.offset)) + .collect(Collectors.toList()); position = from.position(); limit = from.limit(); } @@ -277,6 +317,19 @@ public ByteBuf asReadOnly() { throw new UnsupportedOperationException(); } + /** + * Creates a duplicate of this composite buffer with independent position and limit. + * + *

ByteBuf Ownership: The duplicate calls {@link ByteBuf#retain()} on each + * component buffer, incrementing their reference counts. The duplicate owns these retained + * references and starts with its own reference count of 1. The caller is responsible + * for releasing the duplicate when done, which will release the component buffers.

+ * + *

The duplicate shares the underlying buffer data with this composite but has independent + * reference counting and position/limit state.

+ * + * @return a new composite buffer that shares data with this one but has independent state + */ @Override public ByteBuf duplicate() { return new CompositeByteBuf(this); @@ -300,21 +353,42 @@ public int getReferenceCount() { return referenceCount.get(); } + /** + * Increments the reference count of this composite buffer and all component buffers. + * + *

Important: This method retains both the composite's reference count and all + * component buffers. If the reference count is already 0, an {@link IllegalStateException} is thrown. + * Note that if an exception is thrown, the component buffers will have been retained before the + * exception occurs.

+ * + * @return this buffer + * @throws IllegalStateException if the reference count is already 0 + */ @Override public ByteBuf retain() { if (referenceCount.incrementAndGet() == 1) { referenceCount.decrementAndGet(); throw new IllegalStateException("Attempted to increment the reference count when it is already 0"); } + components.forEach(c -> c.buffer.retain()); return this; } + /** + * Decrements the reference count of this composite buffer and all component buffers. + * + *

Important: This method releases both the composite's reference count and all + * component buffers. All component buffers are released even if an exception occurs.

+ * + * @throws IllegalStateException if the reference count is already 0 + */ @Override public void release() { if (referenceCount.decrementAndGet() < 0) { referenceCount.incrementAndGet(); throw new IllegalStateException("Attempted to decrement the reference count below 0"); } + components.forEach(c -> c.buffer.release()); } private Component findComponent(final int index) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java index 72235b46760..965ac9e951a 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java @@ -20,12 +20,13 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.concurrent.atomic.AtomicInteger; /** *

This class is not part of the public API and may be removed or changed at any time

*/ public final class NettyByteBuf implements ByteBuf { - + private final AtomicInteger referenceCount = new AtomicInteger(1); private io.netty.buffer.ByteBuf proxied; private boolean isWriting = true; @@ -271,17 +272,26 @@ public ByteBuffer asNIO() { @Override public int getReferenceCount() { - return proxied.refCnt(); + return referenceCount.get(); } @Override public ByteBuf retain() { + if (referenceCount.incrementAndGet() == 1) { + referenceCount.decrementAndGet(); + throw new IllegalStateException("Attempted to increment the reference count when it is already 0"); + } proxied.retain(); return this; } @Override public void release() { + int newRefCount = referenceCount.decrementAndGet(); + if (newRefCount < 0) { + referenceCount.incrementAndGet(); + throw new IllegalStateException("Attempted to decrement the reference count below 0"); + } proxied.release(); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy deleted file mode 100644 index d052d6b23f1..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection - - -import com.mongodb.internal.connection.netty.NettyByteBuf -import io.netty.buffer.ByteBufAllocator -import io.netty.buffer.PooledByteBufAllocator -import org.bson.ByteBuf -import spock.lang.Specification - -class ByteBufSpecification extends Specification { - def 'should put a byte'() { - given: - def buffer = provider.getBuffer(1024) - - when: - buffer.put((byte) 42) - buffer.flip() - - then: - buffer.get() == 42 - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'should put several bytes'() { - given: - def buffer = provider.getBuffer(1024) - - when: - buffer.with { - put((byte) 42) - put((byte) 43) - put((byte) 44) - flip() - } - - then: - buffer.get() == 42 - buffer.get() == 43 - buffer.get() == 44 - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'should put bytes at index'() { - given: - def buffer = provider.getBuffer(1024) - - when: - buffer.with { - put((byte) 0) - put((byte) 0) - put((byte) 0) - put((byte) 0) - put((byte) 43) - put((byte) 44) - put(0, (byte) 22) - put(1, (byte) 23) - put(2, (byte) 24) - put(3, (byte) 25) - flip() - } - - then: - buffer.get() == 22 - buffer.get() == 23 - buffer.get() == 24 - buffer.get() == 25 - buffer.get() == 43 - buffer.get() == 44 - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'when writing, remaining should be the number of bytes that can be written'() { - when: - def buffer = provider.getBuffer(1024) - - then: - buffer.remaining() == 1024 - - and: - buffer.put((byte) 1) - - then: - buffer.remaining() == 1023 - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'when writing, hasRemaining should be true if there is still room to write'() { - when: - def buffer = provider.getBuffer(2) - - then: - buffer.hasRemaining() - - and: - buffer.put((byte) 1) - - then: - buffer.hasRemaining() - - and: - buffer.put((byte) 1) - - then: - !buffer.hasRemaining() - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'should return NIO buffer with the same capacity and limit'() { - given: - def buffer = provider.getBuffer(36) - - when: - def nioBuffer = buffer.asNIO() - - then: - nioBuffer.limit() == 36 - nioBuffer.position() == 0 - nioBuffer.remaining() == 36 - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'should return NIO buffer with the same contents'() { - given: - def buffer = provider.getBuffer(1024) - - buffer.with { - put((byte) 42) - put((byte) 43) - put((byte) 44) - put((byte) 45) - put((byte) 46) - put((byte) 47) - - flip() - } - - when: - def nioBuffer = buffer.asNIO() - - then: - nioBuffer.limit() == 6 - nioBuffer.position() == 0 - nioBuffer.get() == 42 - nioBuffer.get() == 43 - nioBuffer.get() == 44 - nioBuffer.get() == 45 - nioBuffer.get() == 46 - nioBuffer.get() == 47 - nioBuffer.remaining() == 0 - - cleanup: - buffer.release() - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - def 'should enforce reference counts'() { - when: - def buffer = provider.getBuffer(1024) - buffer.put((byte) 1) - - then: - buffer.referenceCount == 1 - - when: - buffer.retain() - buffer.put((byte) 1) - - then: - buffer.referenceCount == 2 - - when: - buffer.release() - buffer.put((byte) 1) - - then: - buffer.referenceCount == 1 - - when: - buffer.release() - - then: - buffer.referenceCount == 0 - - when: - buffer.put((byte) 1) - - then: - thrown(Exception) - - where: - provider << [new NettyBufferProvider(), new SimpleBufferProvider()] - } - - static final class NettyBufferProvider implements BufferProvider { - private final ByteBufAllocator allocator - - NettyBufferProvider() { - allocator = PooledByteBufAllocator.DEFAULT - } - - @Override - ByteBuf getBuffer(final int size) { - io.netty.buffer.ByteBuf buffer = allocator.directBuffer(size, size) - new NettyByteBuf(buffer) - } - } -} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java index 722d7d62fa4..f8d71fae3e3 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java @@ -18,24 +18,33 @@ import org.bson.ByteBuf; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import java.nio.ByteBuffer; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static com.mongodb.internal.connection.TestBufferProviders.TrackingBufferProvider; class ByteBufTest { - static Stream bufferProviders() { - return Stream.of(new ByteBufSpecification.NettyBufferProvider(), new SimpleBufferProvider()); + static Stream bufferProviders() { + return TestBufferProviders.bufferProviders(); } - @ParameterizedTest + @DisplayName("Should write and read an int value") + @ParameterizedTest(name = "with {0}") @MethodSource("bufferProviders") - void shouldPutInt(final BufferProvider provider) { - ByteBuf buffer = provider.getBuffer(1024); + void shouldPutInt(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); try { buffer.putInt(42); buffer.flip(); @@ -43,12 +52,14 @@ void shouldPutInt(final BufferProvider provider) { } finally { buffer.release(); } + bufferProvider.assertAllUnavailable(); } - @ParameterizedTest + @DisplayName("Should write and read a long value") + @ParameterizedTest(name = "with {0}") @MethodSource("bufferProviders") - void shouldPutLong(final BufferProvider provider) { - ByteBuf buffer = provider.getBuffer(1024); + void shouldPutLong(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); try { buffer.putLong(42L); buffer.flip(); @@ -56,12 +67,14 @@ void shouldPutLong(final BufferProvider provider) { } finally { buffer.release(); } + bufferProvider.assertAllUnavailable(); } - @ParameterizedTest + @DisplayName("Should write and read a double value") + @ParameterizedTest(name = "with {0}") @MethodSource("bufferProviders") - void shouldPutDouble(final BufferProvider provider) { - ByteBuf buffer = provider.getBuffer(1024); + void shouldPutDouble(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); try { buffer.putDouble(42.0D); buffer.flip(); @@ -69,12 +82,14 @@ void shouldPutDouble(final BufferProvider provider) { } finally { buffer.release(); } + bufferProvider.assertAllUnavailable(); } - @ParameterizedTest + @DisplayName("Should write and read int values at specific indices") + @ParameterizedTest(name = "with {0}") @MethodSource("bufferProviders") - void shouldPutIntAtIndex(final BufferProvider provider) { - ByteBuf buffer = provider.getBuffer(1024); + void shouldPutIntAtIndex(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); try { buffer.putInt(0); buffer.putInt(0); @@ -97,5 +112,422 @@ void shouldPutIntAtIndex(final BufferProvider provider) { } finally { buffer.release(); } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should write and read a single byte") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldPutAByte(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 42); + buffer.flip(); + assertEquals(42, buffer.get()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should write and read multiple bytes in sequence") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldPutSeveralBytes(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 42); + buffer.put((byte) 43); + buffer.put((byte) 44); + buffer.flip(); + + assertEquals(42, buffer.get()); + assertEquals(43, buffer.get()); + assertEquals(44, buffer.get()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should write and read bytes at specific indices") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldPutBytesAtIndex(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 0); + buffer.put((byte) 0); + buffer.put((byte) 0); + buffer.put((byte) 0); + buffer.put((byte) 43); + buffer.put((byte) 44); + buffer.put(0, (byte) 22); + buffer.put(1, (byte) 23); + buffer.put(2, (byte) 24); + buffer.put(3, (byte) 25); + buffer.flip(); + + assertEquals(22, buffer.get()); + assertEquals(23, buffer.get()); + assertEquals(24, buffer.get()); + assertEquals(25, buffer.get()); + assertEquals(43, buffer.get()); + assertEquals(44, buffer.get()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Remaining should decrease as bytes are written") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void whenWritingRemainingIsTheNumberOfBytesThatCanBeWritten(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + assertEquals(1024, buffer.remaining()); + buffer.put((byte) 1); + assertEquals(1023, buffer.remaining()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("HasRemaining should be true while space is available and false when full") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void whenWritingHasRemainingShouldBeTrueIfThereIsStillRoomToWrite(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(2); + try { + assertTrue(buffer.hasRemaining()); + buffer.put((byte) 1); + assertTrue(buffer.hasRemaining()); + buffer.put((byte) 1); + assertFalse(buffer.hasRemaining()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("NIO buffer conversion should preserve capacity and limit") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldReturnNIOBufferWithTheSameCapacityAndLimit(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(36); + try { + ByteBuffer nioBuffer = buffer.asNIO(); + assertEquals(36, nioBuffer.limit()); + assertEquals(0, nioBuffer.position()); + assertEquals(36, nioBuffer.remaining()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("NIO buffer conversion should preserve contents") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldReturnNIOBufferWithTheSameContents(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 42); + buffer.put((byte) 43); + buffer.put((byte) 44); + buffer.put((byte) 45); + buffer.put((byte) 46); + buffer.put((byte) 47); + buffer.flip(); + + ByteBuffer nioBuffer = buffer.asNIO(); + assertEquals(6, nioBuffer.limit()); + assertEquals(0, nioBuffer.position()); + assertEquals(42, nioBuffer.get()); + assertEquals(43, nioBuffer.get()); + assertEquals(44, nioBuffer.get()); + assertEquals(45, nioBuffer.get()); + assertEquals(46, nioBuffer.get()); + assertEquals(47, nioBuffer.get()); + assertEquals(0, nioBuffer.remaining()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Reference counting should increment on retain and decrement on release") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldEnforceReferenceCounts(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + buffer.put((byte) 1); + assertEquals(1, buffer.getReferenceCount()); + + buffer.retain(); + buffer.put((byte) 1); + assertEquals(2, buffer.getReferenceCount()); + + buffer.release(); + buffer.put((byte) 1); + assertEquals(1, buffer.getReferenceCount()); + + buffer.release(); + assertEquals(0, buffer.getReferenceCount()); + + Assertions.assertThrows(Exception.class, () -> buffer.put((byte) 1)); + + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Position should track current write position") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldTrackPosition(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + assertEquals(0, buffer.position()); + buffer.putInt(42); + assertEquals(4, buffer.position()); + buffer.putLong(100L); + assertEquals(12, buffer.position()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Clear should reset position and limit") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldClearPositionAndLimit(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 1); + buffer.put((byte) 2); + buffer.put((byte) 3); + assertEquals(3, buffer.position()); + + buffer.clear(); + assertEquals(0, buffer.position()); + assertEquals(1024, buffer.limit()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Bulk put should write multiple bytes from array") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldPutBulkBytes(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + byte[] data = {10, 20, 30, 40, 50}; + buffer.put(data, 0, 5); + buffer.flip(); + + for (byte b : data) { + assertEquals(b, buffer.get()); + } + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Bulk get should read multiple bytes into array") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldGetBulkBytes(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + byte[] original = {10, 20, 30, 40, 50}; + buffer.put(original, 0, 5); + buffer.flip(); + + byte[] read = new byte[5]; + buffer.get(read); + + for (int i = 0; i < 5; i++) { + assertEquals(original[i], read[i]); + } + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Multiple retain should increase reference count correctly") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldHandleMultipleRetain(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + assertEquals(1, buffer.getReferenceCount()); + + buffer.retain(); + buffer.retain(); + buffer.retain(); + assertEquals(4, buffer.getReferenceCount()); + + buffer.release(); + assertEquals(3, buffer.getReferenceCount()); + buffer.release(); + assertEquals(2, buffer.getReferenceCount()); + buffer.release(); + assertEquals(1, buffer.getReferenceCount()); + buffer.release(); + assertEquals(0, buffer.getReferenceCount()); + + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Position should track current offset") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldSetPositionMethod(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + assertEquals(0, buffer.position()); + buffer.put((byte) 1); + assertEquals(1, buffer.position()); + buffer.put((byte) 2); + assertEquals(2, buffer.position()); + buffer.put((byte) 3); + assertEquals(3, buffer.position()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should handle absolute byte get at specific index") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldGetByteAtIndex(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 10); + buffer.put((byte) 20); + buffer.put((byte) 30); + + assertEquals(10, buffer.get(0)); + assertEquals(20, buffer.get(1)); + assertEquals(30, buffer.get(2)); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should handle absolute int get at specific index") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldGetIntAtIndex(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.putInt(0, 12345); + buffer.putInt(4, 67890); + + assertEquals(12345, buffer.getInt(0)); + assertEquals(67890, buffer.getInt(4)); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should handle absolute long get at specific index") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldGetLongAtIndex(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + long value1 = 1234567890L; + long value2 = 9876543210L; + + buffer.putLong(value1); + buffer.putLong(value2); + buffer.position(0); + + assertEquals(value1, buffer.getLong(0)); + assertEquals(value2, buffer.getLong(8)); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should handle absolute double get at specific index") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldGetDoubleAtIndex(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + double value1 = 123.456; + double value2 = 789.012; + + buffer.putDouble(value1); + buffer.putDouble(value2); + buffer.position(0); + + assertEquals(value1, buffer.getDouble(0)); + assertEquals(value2, buffer.getDouble(8)); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Should handle putLong method") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldPutLongRelative(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.putLong(1234567890L); + buffer.putLong(9876543210L); + buffer.flip(); + + assertEquals(1234567890L, buffer.getLong()); + assertEquals(9876543210L, buffer.getLong()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Retain should return the buffer for chaining") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldReturnBufferFromRetain(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + ByteBuf retained = buffer.retain(); + assertEquals(2, buffer.getReferenceCount()); + assertNotNull(retained); + buffer.release(); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); + } + + @DisplayName("Flip should return the buffer for chaining") + @ParameterizedTest(name = "with {0}") + @MethodSource("bufferProviders") + void shouldReturnBufferFromFlip(final String description, final TrackingBufferProvider bufferProvider) { + ByteBuf buffer = bufferProvider.getBuffer(1024); + try { + buffer.put((byte) 1); + ByteBuf flipped = buffer.flip(); + assertNotNull(flipped); + assertEquals(1, buffer.get()); + } finally { + buffer.release(); + } + bufferProvider.assertAllUnavailable(); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java index 30a70042578..3191f909266 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; import com.mongodb.internal.connection.netty.NettyByteBuf; +import org.bson.BsonBinaryWriter; import org.bson.ByteBuf; import org.bson.ByteBufNIO; import org.junit.jupiter.api.DisplayName; @@ -41,23 +42,56 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static com.mongodb.internal.connection.TestBufferProviders.TrackingBufferProvider; - +@DisplayName("CompositeByteBuf") final class CompositeByteBufTest { + // Construction tests + @Test + @DisplayName("Construction: should throw IllegalArgumentException when buffers is null") @SuppressWarnings("ConstantConditions") - void shouldThrowIfBuffersIsNull() { + void constructorShouldThrowIfBuffersIsNull() { assertThrows(IllegalArgumentException.class, () -> new CompositeByteBuf((List) null)); } @Test - void shouldThrowIfBuffersIsEmpty() { + @DisplayName("Construction: should throw IllegalArgumentException when buffers is empty") + void constructorShouldThrowIfBuffersIsEmpty() { assertThrows(IllegalArgumentException.class, () -> new CompositeByteBuf(emptyList())); } - @DisplayName("referenceCount should be maintained") - @ParameterizedTest + @Test + @DisplayName("Construction: should calculate capacity as sum of all buffer limits") + void constructorShouldCalculateCapacityAsSumOfBufferLimits() { + assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).capacity()); + assertEquals(6, new CompositeByteBuf(asList( + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})) + )).capacity()); + } + + @Test + @DisplayName("Construction: should initialize position to zero") + void constructorShouldInitializePositionToZero() { + assertEquals(0, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).position()); + } + + @Test + @DisplayName("Construction: should initialize limit as sum of all buffer limits") + void constructorShouldInitializeLimitAsSumOfBufferLimits() { + assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).limit()); + assertEquals(6, new CompositeByteBuf(asList( + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})) + )).limit()); + } + + // Reference counting tests + + @DisplayName("Reference counting: should maintain reference count correctly") + @ParameterizedTest(name = "with {0}") @MethodSource("getBuffers") void referenceCountShouldBeMaintained(final List buffers) { CompositeByteBuf buf = new CompositeByteBuf(buffers); @@ -76,108 +110,198 @@ void referenceCountShouldBeMaintained(final List buffers) { assertThrows(IllegalStateException.class, buf::retain); } - private static Stream getBuffers() { - return Stream.of( - Arguments.of(Named.of("ByteBufNIO", - asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), - new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))))), - Arguments.of(Named.of("NettyByteBuf", - asList(new NettyByteBuf(copiedBuffer(new byte[]{1, 2, 3, 4})), - new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))), - Arguments.of(Named.of("Mixed NIO and NettyByteBuf", - asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), - new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))) - ); + @ParameterizedTest(name = "with {0}") + @DisplayName("Reference counting: should release underlying buffers when reference count reaches zero") + @MethodSource("bufferProviders") + void releaseShouldReleaseUnderlyingBuffers(final String description, final TrackingBufferProvider bufferProvider) { + List buffers = asList(bufferProvider.getBuffer(1), bufferProvider.getBuffer(1)); + CompositeByteBuf compositeByteBuf = new CompositeByteBuf(buffers); + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() > 0)); + bufferProvider.assertAllAvailable(); + + compositeByteBuf.release(); + buffers.forEach(ByteBuf::release); + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0)); + bufferProvider.assertAllUnavailable(); + } + + @ParameterizedTest(name = "with {0}") + @DisplayName("Reference counting: duplicate should have independent reference count from original") + @MethodSource("bufferProviders") + void duplicateShouldHaveIndependentReferenceCount(final String description, final TrackingBufferProvider bufferProvider) { + List buffers = asList(bufferProvider.getBuffer(1), bufferProvider.getBuffer(1)); + CompositeByteBuf compositeBuffer = new CompositeByteBuf(buffers); + assertEquals(1, compositeBuffer.getReferenceCount()); + + ByteBuf compositeBufferDuplicate = compositeBuffer.duplicate(); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + assertEquals(1, compositeBuffer.getReferenceCount()); + + compositeBuffer.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + + compositeBufferDuplicate.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(0, compositeBufferDuplicate.getReferenceCount()); + + bufferProvider.assertAllAvailable(); + buffers.forEach(ByteBuf::release); + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0)); + bufferProvider.assertAllUnavailable(); + } + + @ParameterizedTest(name = "with {0}") + @DisplayName("Reference counting: should work correctly with BsonBinaryWriter") + @MethodSource("bufferProviders") + void shouldWorkCorrectlyWithBsonBinaryWriter(final String description, final TrackingBufferProvider bufferProvider) { + List buffers; + + try (ByteBufferBsonOutput bufferBsonOutput = new ByteBufferBsonOutput(bufferProvider)) { + try (BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(bufferBsonOutput)) { + bsonBinaryWriter.writeStartDocument(); + bsonBinaryWriter.writeName("k"); + bsonBinaryWriter.writeInt32(42); + bsonBinaryWriter.writeEndDocument(); + bsonBinaryWriter.flush(); + } + buffers = bufferBsonOutput.getByteBuffers(); + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() > 0)); + + CompositeByteBuf compositeBuffer = new CompositeByteBuf(buffers); + assertEquals(1, compositeBuffer.getReferenceCount()); + + ByteBuf compositeBufferDuplicate = compositeBuffer.duplicate(); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + assertEquals(1, compositeBuffer.getReferenceCount()); + + compositeBuffer.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + + compositeBufferDuplicate.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(0, compositeBufferDuplicate.getReferenceCount()); + + bufferProvider.assertAllAvailable(); + buffers.forEach(ByteBuf::release); + } + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0)); + bufferProvider.assertAllUnavailable(); } @Test + @DisplayName("Reference counting: should throw IllegalStateException when accessing released buffer") + void shouldThrowIllegalStateExceptionIfBufferIsClosed() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); + buf.release(); + + assertThrows(IllegalStateException.class, buf::get); + } + + // Byte order tests + + @Test + @DisplayName("Byte order: should throw UnsupportedOperationException for BIG_ENDIAN byte order") void orderShouldThrowIfNotLittleEndian() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); assertThrows(UnsupportedOperationException.class, () -> buf.order(ByteOrder.BIG_ENDIAN)); } @Test + @DisplayName("Byte order: should accept LITTLE_ENDIAN byte order") void orderShouldReturnNormallyIfLittleEndian() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); assertDoesNotThrow(() -> buf.order(ByteOrder.LITTLE_ENDIAN)); } + // Position tests + @Test - void limitShouldBeSumOfLimitsOfBuffers() { - assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).limit()); + @DisplayName("Position: should set position when within valid range") + void positionShouldBeSetIfInRange() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); - assertEquals(6, new CompositeByteBuf(asList( - new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), - new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})) - )).limit()); + for (int i = 0; i <= 3; i++) { + buf.position(i); + assertEquals(i, buf.position()); + } } @Test - void capacityShouldBeTheInitialLimit() { - assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).capacity()); - assertEquals(6, new CompositeByteBuf(asList( - new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), - new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})) - )).capacity()); + @DisplayName("Position: should throw IndexOutOfBoundsException when position is negative") + void positionShouldThrowForNegativeValue() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); + assertThrows(IndexOutOfBoundsException.class, () -> buf.position(-1)); } @Test - void positionShouldBeZero() { - assertEquals(0, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).position()); + @DisplayName("Position: should throw IndexOutOfBoundsException when position exceeds capacity") + void positionShouldThrowWhenExceedsCapacity() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); + assertThrows(IndexOutOfBoundsException.class, () -> buf.position(4)); } @Test - void positionShouldBeSetIfInRange() { + @DisplayName("Position: should throw IndexOutOfBoundsException when position exceeds limit") + void positionShouldThrowWhenExceedsLimit() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); - buf.position(0); - assertEquals(0, buf.position()); - - buf.position(1); - assertEquals(1, buf.position()); - - buf.position(2); - assertEquals(2, buf.position()); - - buf.position(3); - assertEquals(3, buf.position()); + buf.limit(2); + assertThrows(IndexOutOfBoundsException.class, () -> buf.position(3)); } @Test - void positionShouldThrowIfOutOfRange() { - CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); + @DisplayName("Position: should update remaining bytes as position changes during reads") + void positionRemainingAndHasRemainingShouldUpdateAsBytesAreRead() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); - assertThrows(IndexOutOfBoundsException.class, () -> buf.position(-1)); - assertThrows(IndexOutOfBoundsException.class, () -> buf.position(4)); + for (int i = 0; i < 4; i++) { + assertEquals(i, buf.position()); + assertEquals(4 - i, buf.remaining()); + assertTrue(buf.hasRemaining()); + buf.get(); + } - buf.limit(2); - assertThrows(IndexOutOfBoundsException.class, () -> buf.position(3)); + assertEquals(4, buf.position()); + assertEquals(0, buf.remaining()); + assertFalse(buf.hasRemaining()); } + // Limit tests + @Test + @DisplayName("Limit: should set limit when within valid range") void limitShouldBeSetIfInRange() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); - buf.limit(0); - assertEquals(0, buf.limit()); - buf.limit(1); - assertEquals(1, buf.limit()); - - buf.limit(2); - assertEquals(2, buf.limit()); - - buf.limit(3); - assertEquals(3, buf.limit()); + for (int i = 0; i <= 3; i++) { + buf.limit(i); + assertEquals(i, buf.limit()); + } } @Test - void limitShouldThrowIfOutOfRange() { + @DisplayName("Limit: should throw IndexOutOfBoundsException when limit is negative") + void limitShouldThrowForNegativeValue() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); - assertThrows(IndexOutOfBoundsException.class, () -> buf.limit(-1)); + } + + @Test + @DisplayName("Limit: should throw IndexOutOfBoundsException when limit exceeds capacity") + void limitShouldThrowWhenExceedsCapacity() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); assertThrows(IndexOutOfBoundsException.class, () -> buf.limit(4)); } + // Clear tests + @Test + @DisplayName("Clear: should reset position to zero and limit to capacity") void clearShouldResetPositionAndLimit() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3})))); buf.limit(2); @@ -188,7 +312,10 @@ void clearShouldResetPositionAndLimit() { assertEquals(3, buf.limit()); } + // Duplicate tests + @Test + @DisplayName("Duplicate: should copy position and limit to duplicate") void duplicateShouldCopyAllProperties() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 1, 2, 3, 4, 1, 2})))); buf.limit(6); @@ -203,35 +330,32 @@ void duplicateShouldCopyAllProperties() { assertEquals(2, buf.position()); } + // Get byte tests + @Test - void positionRemainingAndHasRemainingShouldUpdateAsBytesAreRead() { + @DisplayName("Get byte: relative get should read byte and move position") + void relativeGetShouldReadByteAndMovePosition() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); - assertEquals(0, buf.position()); - assertEquals(4, buf.remaining()); - assertTrue(buf.hasRemaining()); - buf.get(); + assertEquals(1, buf.get()); assertEquals(1, buf.position()); - assertEquals(3, buf.remaining()); - assertTrue(buf.hasRemaining()); - - buf.get(); + assertEquals(2, buf.get()); assertEquals(2, buf.position()); - assertEquals(2, buf.remaining()); - assertTrue(buf.hasRemaining()); + } - buf.get(); - assertEquals(3, buf.position()); - assertEquals(1, buf.remaining()); - assertTrue(buf.hasRemaining()); + @Test + @DisplayName("Get byte: should throw IndexOutOfBoundsException when reading past limit") + void getShouldThrowWhenReadingPastLimit() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); + buf.position(4); - buf.get(); - assertEquals(4, buf.position()); - assertEquals(0, buf.remaining()); - assertFalse(buf.hasRemaining()); + assertThrows(IndexOutOfBoundsException.class, buf::get); } + // Get int tests + @Test + @DisplayName("Get int: absolute getInt should read little-endian integer and preserve position") void absoluteGetIntShouldReadLittleEndianIntegerAndPreservePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -243,6 +367,7 @@ void absoluteGetIntShouldReadLittleEndianIntegerAndPreservePosition() { } @Test + @DisplayName("Get int: absolute getInt should read correctly when integer is split across buffers") void absoluteGetIntShouldReadLittleEndianIntegerWhenIntegerIsSplitAcrossBuffers() { ByteBufNIO byteBufferOne = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})); ByteBuf byteBufferTwo = new NettyByteBuf(wrappedBuffer(new byte[]{3, 4})).flip(); @@ -256,16 +381,30 @@ void absoluteGetIntShouldReadLittleEndianIntegerWhenIntegerIsSplitAcrossBuffers( } @Test + @DisplayName("Get int: relative getInt should read little-endian integer and move position") void relativeGetIntShouldReadLittleEndianIntegerAndMovePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); int i = buf.getInt(); + assertEquals(67305985, i); assertEquals(4, buf.position()); assertEquals(0, byteBuffer.position()); } @Test + @DisplayName("Get int: should throw IndexOutOfBoundsException when not enough bytes for int") + void getIntShouldThrowWhenNotEnoughBytesForInt() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); + buf.position(1); + + assertThrows(IndexOutOfBoundsException.class, buf::getInt); + } + + // Get long tests + + @Test + @DisplayName("Get long: absolute getLong should read little-endian long and preserve position") void absoluteGetLongShouldReadLittleEndianLongAndPreservePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -277,7 +416,8 @@ void absoluteGetLongShouldReadLittleEndianLongAndPreservePosition() { } @Test - void absoluteGetLongShouldReadLittleEndianLongWhenDoubleIsSplitAcrossBuffers() { + @DisplayName("Get long: absolute getLong should read correctly when long is split across multiple buffers") + void absoluteGetLongShouldReadLittleEndianLongWhenSplitAcrossBuffers() { ByteBuf byteBufferOne = new NettyByteBuf(wrappedBuffer(new byte[]{1, 2})).flip(); ByteBuf byteBufferTwo = new ByteBufNIO(ByteBuffer.wrap(new byte[]{3, 4})); ByteBuf byteBufferThree = new NettyByteBuf(wrappedBuffer(new byte[]{5, 6})).flip(); @@ -287,11 +427,10 @@ void absoluteGetLongShouldReadLittleEndianLongWhenDoubleIsSplitAcrossBuffers() { assertEquals(578437695752307201L, l); assertEquals(0, buf.position()); - assertEquals(0, byteBufferOne.position()); - assertEquals(0, byteBufferTwo.position()); } @Test + @DisplayName("Get long: relative getLong should read little-endian long and move position") void relativeGetLongShouldReadLittleEndianLongAndMovePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -302,7 +441,10 @@ void relativeGetLongShouldReadLittleEndianLongAndMovePosition() { assertEquals(0, byteBuffer.position()); } + // Get double tests + @Test + @DisplayName("Get double: absolute getDouble should read little-endian double and preserve position") void absoluteGetDoubleShouldReadLittleEndianDoubleAndPreservePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -314,6 +456,7 @@ void absoluteGetDoubleShouldReadLittleEndianDoubleAndPreservePosition() { } @Test + @DisplayName("Get double: relative getDouble should read little-endian double and move position") void relativeGetDoubleShouldReadLittleEndianDoubleAndMovePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -324,7 +467,10 @@ void relativeGetDoubleShouldReadLittleEndianDoubleAndMovePosition() { assertEquals(0, byteBuffer.position()); } + // Bulk get tests + @Test + @DisplayName("Bulk get: absolute bulk get should read bytes and preserve position") void absoluteBulkGetShouldReadBytesAndPreservePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -337,6 +483,7 @@ void absoluteBulkGetShouldReadBytesAndPreservePosition() { } @Test + @DisplayName("Bulk get: absolute bulk get should read bytes split across multiple buffers") void absoluteBulkGetShouldReadBytesWhenSplitAcrossBuffers() { ByteBufNIO byteBufferOne = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1})); ByteBufNIO byteBufferTwo = new ByteBufNIO(ByteBuffer.wrap(new byte[]{2, 3})); @@ -354,6 +501,7 @@ void absoluteBulkGetShouldReadBytesWhenSplitAcrossBuffers() { } @Test + @DisplayName("Bulk get: relative bulk get should read bytes and move position") void relativeBulkGetShouldReadBytesAndMovePosition() { ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8})); CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer)); @@ -372,6 +520,16 @@ void relativeBulkGetShouldReadBytesAndMovePosition() { } @Test + @DisplayName("Bulk get: should throw IndexOutOfBoundsException when bulk get exceeds remaining") + void bulkGetShouldThrowWhenBulkGetExceedsRemaining() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); + assertThrows(IndexOutOfBoundsException.class, () -> buf.get(new byte[2], 1, 5)); + } + + // asNIO tests + + @Test + @DisplayName("asNIO: should get as NIO ByteBuffer with correct position and limit") void shouldGetAsNIOByteBuffer() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8})))); buf.position(1).limit(5); @@ -386,6 +544,7 @@ void shouldGetAsNIOByteBuffer() { } @Test + @DisplayName("asNIO: should consolidate multiple buffers into single NIO ByteBuffer") void shouldGetAsNIOByteBufferWithMultipleBuffers() { CompositeByteBuf buf = new CompositeByteBuf(asList( new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})), @@ -403,25 +562,23 @@ void shouldGetAsNIOByteBufferWithMultipleBuffers() { assertArrayEquals(new byte[]{2, 3, 4, 5, 6}, bytes); } - @Test - void shouldThrowIndexOutOfBoundsExceptionIfReadingOutOfBounds() { - CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); - buf.position(4); + // Test data providers - assertThrows(IndexOutOfBoundsException.class, buf::get); - buf.position(1); - - assertThrows(IndexOutOfBoundsException.class, buf::getInt); - buf.position(0); - - assertThrows(IndexOutOfBoundsException.class, () -> buf.get(new byte[2], 1, 2)); + static Stream getBuffers() { + return Stream.of( + Arguments.of(Named.of("ByteBufNIO", + asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))))), + Arguments.of(Named.of("NettyByteBuf", + asList(new NettyByteBuf(copiedBuffer(new byte[]{1, 2, 3, 4})), + new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))), + Arguments.of(Named.of("Mixed NIO and NettyByteBuf", + asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))) + ); } - @Test - void shouldThrowIllegalStateExceptionIfBufferIsClosed() { - CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); - buf.release(); - - assertThrows(IllegalStateException.class, buf::get); + static Stream bufferProviders() { + return TestBufferProviders.bufferProviders(); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java b/driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java new file mode 100644 index 00000000000..26da7220ce5 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java @@ -0,0 +1,407 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.internal.connection.netty.NettyByteBuf; +import com.mongodb.lang.NonNull; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.bson.ByteBuf; +import org.bson.ByteBufNIO; +import org.junit.jupiter.params.provider.Arguments; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Shared buffer providers for testing ByteBuf implementations. + */ +public final class TestBufferProviders { + private TestBufferProviders() { + } + + + public static java.util.stream.Stream bufferProviders() { + TestBufferProviders.TrackingBufferProvider nioBufferProvider = TestBufferProviders.trackingBufferProvider(size -> new ByteBufNIO(ByteBuffer.allocate(size))); + PowerOfTwoBufferPool bufferPool = new PowerOfTwoBufferPool(1); + bufferPool.disablePruning(); + TestBufferProviders.TrackingBufferProvider pooledNioBufferProvider = TestBufferProviders.trackingBufferProvider(bufferPool); + TestBufferProviders.TrackingBufferProvider nettyBufferProvider = TestBufferProviders.trackingNettyBufferProvider(); + return Stream.of( + Arguments.of("NIO", nioBufferProvider), + Arguments.of("pooled NIO", pooledNioBufferProvider), + Arguments.of("Netty", nettyBufferProvider)); + } + + /** + * Creates a NettyBufferProvider that validates cleanup. + */ + public static BufferProvider nettyValidatingBufferProvider() { + return new NettyBufferProvider(); + } + + /** + * Creates a TrackingBufferProvider that tracks allocated buffers for validation. + */ + public static TrackingBufferProvider trackingNettyBufferProvider() { + return new TrackingBufferProvider(size -> new NettyByteBuf(UnpooledByteBufAllocator.DEFAULT.buffer(size, size))); + } + + /** + * Creates a TrackingBufferProvider that wraps any BufferProvider. + */ + public static TrackingBufferProvider trackingBufferProvider(final BufferProvider provider) { + return new TrackingBufferProvider(provider); + } + + /** + * Creates a TrackingBufferProvider that wraps a functional provider. + */ + public static TrackingBufferProvider trackingBufferProvider(final FunctionalBufferProvider provider) { + return new TrackingBufferProvider(provider); + } + + /** + * Functional interface for creating buffers of a given size. + */ + public interface FunctionalBufferProvider extends BufferProvider { + @Override + @NonNull + ByteBuf getBuffer(int size); + } + + /** + * A NettyBufferProvider that validates cleanup and prevents use-after-free. + */ + private static final class NettyBufferProvider implements BufferProvider { + private final ByteBufAllocator allocator; + + NettyBufferProvider() { + allocator = PooledByteBufAllocator.DEFAULT; + } + + @Override + public ByteBuf getBuffer(final int size) { + io.netty.buffer.ByteBuf nettyBuffer = allocator.directBuffer(size, size); + return new ValidatingNettyByteBuf(new NettyByteBuf(nettyBuffer)); + } + + private static final class ValidatingNettyByteBuf implements ByteBuf { + private final NettyByteBuf delegate; + private boolean released = false; + + ValidatingNettyByteBuf(final NettyByteBuf delegate) { + this.delegate = delegate; + } + + @Override + public int capacity() { + validateNotReleased(); + return delegate.capacity(); + } + + @Override + public ByteBuf put(final int index, final byte value) { + validateNotReleased(); + return delegate.put(index, value); + } + + @Override + public int remaining() { + validateNotReleased(); + return delegate.remaining(); + } + + @Override + public ByteBuf put(final byte[] src, final int offset, final int length) { + validateNotReleased(); + return delegate.put(src, offset, length); + } + + @Override + public boolean hasRemaining() { + validateNotReleased(); + return delegate.hasRemaining(); + } + + @Override + public ByteBuf put(final byte value) { + validateNotReleased(); + return delegate.put(value); + } + + @Override + public ByteBuf putInt(final int value) { + validateNotReleased(); + return delegate.putInt(value); + } + + @Override + public ByteBuf putInt(final int index, final int value) { + validateNotReleased(); + return delegate.putInt(index, value); + } + + @Override + public ByteBuf putDouble(final double value) { + validateNotReleased(); + return delegate.putDouble(value); + } + + @Override + public ByteBuf putLong(final long value) { + validateNotReleased(); + return delegate.putLong(value); + } + + @Override + public ByteBuf flip() { + validateNotReleased(); + return delegate.flip(); + } + + @Override + public byte[] array() { + validateNotReleased(); + return delegate.array(); + } + + @Override + public boolean isBackedByArray() { + validateNotReleased(); + return delegate.isBackedByArray(); + } + + @Override + public int arrayOffset() { + validateNotReleased(); + return delegate.arrayOffset(); + } + + @Override + public int limit() { + validateNotReleased(); + return delegate.limit(); + } + + @Override + public ByteBuf position(final int newPosition) { + validateNotReleased(); + return delegate.position(newPosition); + } + + @Override + public ByteBuf clear() { + validateNotReleased(); + return delegate.clear(); + } + + @Override + public ByteBuf order(final java.nio.ByteOrder byteOrder) { + validateNotReleased(); + return delegate.order(byteOrder); + } + + @Override + public byte get() { + validateNotReleased(); + return delegate.get(); + } + + @Override + public byte get(final int index) { + validateNotReleased(); + return delegate.get(index); + } + + @Override + public ByteBuf get(final byte[] bytes) { + validateNotReleased(); + return delegate.get(bytes); + } + + @Override + public ByteBuf get(final int index, final byte[] bytes) { + validateNotReleased(); + return delegate.get(index, bytes); + } + + @Override + public ByteBuf get(final byte[] bytes, final int offset, final int length) { + validateNotReleased(); + return delegate.get(bytes, offset, length); + } + + @Override + public ByteBuf get(final int index, final byte[] bytes, final int offset, final int length) { + validateNotReleased(); + return delegate.get(index, bytes, offset, length); + } + + @Override + public long getLong() { + validateNotReleased(); + return delegate.getLong(); + } + + @Override + public long getLong(final int index) { + validateNotReleased(); + return delegate.getLong(index); + } + + @Override + public double getDouble() { + validateNotReleased(); + return delegate.getDouble(); + } + + @Override + public double getDouble(final int index) { + validateNotReleased(); + return delegate.getDouble(index); + } + + @Override + public int getInt() { + validateNotReleased(); + return delegate.getInt(); + } + + @Override + public int getInt(final int index) { + validateNotReleased(); + return delegate.getInt(index); + } + + @Override + public int position() { + validateNotReleased(); + return delegate.position(); + } + + @Override + public ByteBuf limit(final int newLimit) { + validateNotReleased(); + return delegate.limit(newLimit); + } + + @Override + public ByteBuf asReadOnly() { + validateNotReleased(); + return delegate.asReadOnly(); + } + + @Override + public ByteBuf duplicate() { + validateNotReleased(); + return delegate.duplicate(); + } + + @Override + public java.nio.ByteBuffer asNIO() { + validateNotReleased(); + return delegate.asNIO(); + } + + @Override + public int getReferenceCount() { + validateNotReleased(); + return delegate.getReferenceCount(); + } + + @Override + public ByteBuf retain() { + validateNotReleased(); + return delegate.retain(); + } + + @Override + public void release() { + if (!released) { + released = true; + delegate.release(); + assertEquals(0, delegate.getReferenceCount(), "Buffer should have reference count 0 after release"); + } + } + + private void validateNotReleased() { + if (released) { + throw new IllegalStateException("Buffer has been released"); + } + } + } + } + + /** + * A BufferProvider that tracks allocated buffers for validation. + */ + public static final class TrackingBufferProvider implements BufferProvider { + private final BufferProvider decorated; + private final List tracked; + + public TrackingBufferProvider(final BufferProvider decorated) { + this.decorated = decorated; + tracked = new ArrayList<>(); + } + + @NonNull + @Override + public ByteBuf getBuffer(final int size) { + ByteBuf result = decorated.getBuffer(size); + tracked.add(result); + return result; + } + + /** + * Asserts that all tracked buffers are still available (reference count > 0). + */ + public void assertAllAvailable() { + for (ByteBuf buffer : tracked) { + assertTrue(buffer.getReferenceCount() > 0); + if (buffer instanceof ByteBufNIO) { + assertNotNull(buffer.asNIO()); + } else if (buffer instanceof NettyByteBuf) { + assertTrue(((NettyByteBuf) buffer).asByteBuf().refCnt() > 0); + } + } + } + + /** + * Asserts that all tracked buffers have been released (reference count = 0). + */ + public void assertAllUnavailable() { + for (ByteBuf buffer : tracked) { + assertEquals(0, buffer.getReferenceCount()); + if (buffer instanceof ByteBufNIO) { + assertNull(buffer.asNIO()); + } + if (buffer instanceof NettyByteBuf) { + assertEquals(0, ((NettyByteBuf) buffer).asByteBuf().refCnt()); + } + } + } + } +}