diff --git a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java
index 67121ecbe3c..e3eb299a40e 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java
@@ -24,17 +24,55 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
import static java.lang.String.format;
import static org.bson.assertions.Assertions.isTrueArgument;
import static org.bson.assertions.Assertions.notNull;
-class CompositeByteBuf implements ByteBuf {
+/**
+ * A composite {@link ByteBuf} that provides a unified view over a list of component buffers.
+ *
+ *
ByteBuf Ownership and Reference Counting
+ * This class manages the lifecycle of its component buffers with the following rules:
+ *
+ * - Constructor: Takes buffers as input but does NOT take ownership. The buffers are made
+ * read-only via {@link ByteBuf#asReadOnly()}, which creates read only duplicates.
+ * The original buffers remain owned by the caller.
+ * - {@link #duplicate()}: Creates a new composite with independent position/limit but calls
+ * {@link ByteBuf#retain()} on each component, incrementing their reference counts. The duplicate
+ * owns these retained references and releases them when it is released.
+ * - {@link #retain()}: Increments the composite's reference count AND retains all component
+ * buffers. Each retain() call must be paired with a {@link #release()}.
+ * - {@link #release()}: Decrements the composite's reference count AND releases all component
+ * buffers. When the count reaches 0, subsequent access will throw {@link IllegalStateException}.
+ *
+ *
+ * Important: The composite's reference count is independent from its components'
+ * reference counts, but they are kept in sync via {@link #retain()} and {@link #release()} methods.
+ *
+ * This class is not part of the public API and may be removed or changed at any time
+ */
+final class CompositeByteBuf implements ByteBuf {
private final List components;
private final AtomicInteger referenceCount = new AtomicInteger(1);
private int position;
private int limit;
+ /**
+ * Creates a composite buffer from the given list of buffers.
+ *
+ * ByteBuf Ownership: This constructor does NOT take ownership of the input buffers.
+ * It calls {@link ByteBuf#asReadOnly()} on each buffer, which creates a shallow read-only view that
+ * shares the same underlying data and reference count as the original. The caller retains ownership
+ * of the original buffers and is responsible for their lifecycle.
+ *
+ * The composite starts with a reference count of 1. When {@link #release()} is called and the
+ * reference count reaches 0, it does NOT automatically release the original buffers - the caller
+ * must handle that separately.
+ *
+ * @param buffers the list of buffers to compose, must not be null or empty
+ */
CompositeByteBuf(final List buffers) {
notNull("buffers", buffers);
isTrueArgument("buffer list not empty", !buffers.isEmpty());
@@ -50,7 +88,9 @@ class CompositeByteBuf implements ByteBuf {
}
private CompositeByteBuf(final CompositeByteBuf from) {
- components = from.components;
+ components = from.components.stream().map(c ->
+ new Component(c.buffer.retain(), c.offset))
+ .collect(Collectors.toList());
position = from.position();
limit = from.limit();
}
@@ -277,6 +317,19 @@ public ByteBuf asReadOnly() {
throw new UnsupportedOperationException();
}
+ /**
+ * Creates a duplicate of this composite buffer with independent position and limit.
+ *
+ * ByteBuf Ownership: The duplicate calls {@link ByteBuf#retain()} on each
+ * component buffer, incrementing their reference counts. The duplicate owns these retained
+ * references and starts with its own reference count of 1. The caller is responsible
+ * for releasing the duplicate when done, which will release the component buffers.
+ *
+ * The duplicate shares the underlying buffer data with this composite but has independent
+ * reference counting and position/limit state.
+ *
+ * @return a new composite buffer that shares data with this one but has independent state
+ */
@Override
public ByteBuf duplicate() {
return new CompositeByteBuf(this);
@@ -300,21 +353,42 @@ public int getReferenceCount() {
return referenceCount.get();
}
+ /**
+ * Increments the reference count of this composite buffer and all component buffers.
+ *
+ * Important: This method retains both the composite's reference count and all
+ * component buffers. If the reference count is already 0, an {@link IllegalStateException} is thrown.
+ * Note that if an exception is thrown, the component buffers will have been retained before the
+ * exception occurs.
+ *
+ * @return this buffer
+ * @throws IllegalStateException if the reference count is already 0
+ */
@Override
public ByteBuf retain() {
if (referenceCount.incrementAndGet() == 1) {
referenceCount.decrementAndGet();
throw new IllegalStateException("Attempted to increment the reference count when it is already 0");
}
+ components.forEach(c -> c.buffer.retain());
return this;
}
+ /**
+ * Decrements the reference count of this composite buffer and all component buffers.
+ *
+ * Important: This method releases both the composite's reference count and all
+ * component buffers. All component buffers are released even if an exception occurs.
+ *
+ * @throws IllegalStateException if the reference count is already 0
+ */
@Override
public void release() {
if (referenceCount.decrementAndGet() < 0) {
referenceCount.incrementAndGet();
throw new IllegalStateException("Attempted to decrement the reference count below 0");
}
+ components.forEach(c -> c.buffer.release());
}
private Component findComponent(final int index) {
diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java
index 72235b46760..965ac9e951a 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java
@@ -20,12 +20,13 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.concurrent.atomic.AtomicInteger;
/**
* This class is not part of the public API and may be removed or changed at any time
*/
public final class NettyByteBuf implements ByteBuf {
-
+ private final AtomicInteger referenceCount = new AtomicInteger(1);
private io.netty.buffer.ByteBuf proxied;
private boolean isWriting = true;
@@ -271,17 +272,26 @@ public ByteBuffer asNIO() {
@Override
public int getReferenceCount() {
- return proxied.refCnt();
+ return referenceCount.get();
}
@Override
public ByteBuf retain() {
+ if (referenceCount.incrementAndGet() == 1) {
+ referenceCount.decrementAndGet();
+ throw new IllegalStateException("Attempted to increment the reference count when it is already 0");
+ }
proxied.retain();
return this;
}
@Override
public void release() {
+ int newRefCount = referenceCount.decrementAndGet();
+ if (newRefCount < 0) {
+ referenceCount.incrementAndGet();
+ throw new IllegalStateException("Attempted to decrement the reference count below 0");
+ }
proxied.release();
}
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy
deleted file mode 100644
index d052d6b23f1..00000000000
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufSpecification.groovy
+++ /dev/null
@@ -1,255 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.mongodb.internal.connection
-
-
-import com.mongodb.internal.connection.netty.NettyByteBuf
-import io.netty.buffer.ByteBufAllocator
-import io.netty.buffer.PooledByteBufAllocator
-import org.bson.ByteBuf
-import spock.lang.Specification
-
-class ByteBufSpecification extends Specification {
- def 'should put a byte'() {
- given:
- def buffer = provider.getBuffer(1024)
-
- when:
- buffer.put((byte) 42)
- buffer.flip()
-
- then:
- buffer.get() == 42
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'should put several bytes'() {
- given:
- def buffer = provider.getBuffer(1024)
-
- when:
- buffer.with {
- put((byte) 42)
- put((byte) 43)
- put((byte) 44)
- flip()
- }
-
- then:
- buffer.get() == 42
- buffer.get() == 43
- buffer.get() == 44
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'should put bytes at index'() {
- given:
- def buffer = provider.getBuffer(1024)
-
- when:
- buffer.with {
- put((byte) 0)
- put((byte) 0)
- put((byte) 0)
- put((byte) 0)
- put((byte) 43)
- put((byte) 44)
- put(0, (byte) 22)
- put(1, (byte) 23)
- put(2, (byte) 24)
- put(3, (byte) 25)
- flip()
- }
-
- then:
- buffer.get() == 22
- buffer.get() == 23
- buffer.get() == 24
- buffer.get() == 25
- buffer.get() == 43
- buffer.get() == 44
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'when writing, remaining should be the number of bytes that can be written'() {
- when:
- def buffer = provider.getBuffer(1024)
-
- then:
- buffer.remaining() == 1024
-
- and:
- buffer.put((byte) 1)
-
- then:
- buffer.remaining() == 1023
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'when writing, hasRemaining should be true if there is still room to write'() {
- when:
- def buffer = provider.getBuffer(2)
-
- then:
- buffer.hasRemaining()
-
- and:
- buffer.put((byte) 1)
-
- then:
- buffer.hasRemaining()
-
- and:
- buffer.put((byte) 1)
-
- then:
- !buffer.hasRemaining()
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'should return NIO buffer with the same capacity and limit'() {
- given:
- def buffer = provider.getBuffer(36)
-
- when:
- def nioBuffer = buffer.asNIO()
-
- then:
- nioBuffer.limit() == 36
- nioBuffer.position() == 0
- nioBuffer.remaining() == 36
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'should return NIO buffer with the same contents'() {
- given:
- def buffer = provider.getBuffer(1024)
-
- buffer.with {
- put((byte) 42)
- put((byte) 43)
- put((byte) 44)
- put((byte) 45)
- put((byte) 46)
- put((byte) 47)
-
- flip()
- }
-
- when:
- def nioBuffer = buffer.asNIO()
-
- then:
- nioBuffer.limit() == 6
- nioBuffer.position() == 0
- nioBuffer.get() == 42
- nioBuffer.get() == 43
- nioBuffer.get() == 44
- nioBuffer.get() == 45
- nioBuffer.get() == 46
- nioBuffer.get() == 47
- nioBuffer.remaining() == 0
-
- cleanup:
- buffer.release()
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- def 'should enforce reference counts'() {
- when:
- def buffer = provider.getBuffer(1024)
- buffer.put((byte) 1)
-
- then:
- buffer.referenceCount == 1
-
- when:
- buffer.retain()
- buffer.put((byte) 1)
-
- then:
- buffer.referenceCount == 2
-
- when:
- buffer.release()
- buffer.put((byte) 1)
-
- then:
- buffer.referenceCount == 1
-
- when:
- buffer.release()
-
- then:
- buffer.referenceCount == 0
-
- when:
- buffer.put((byte) 1)
-
- then:
- thrown(Exception)
-
- where:
- provider << [new NettyBufferProvider(), new SimpleBufferProvider()]
- }
-
- static final class NettyBufferProvider implements BufferProvider {
- private final ByteBufAllocator allocator
-
- NettyBufferProvider() {
- allocator = PooledByteBufAllocator.DEFAULT
- }
-
- @Override
- ByteBuf getBuffer(final int size) {
- io.netty.buffer.ByteBuf buffer = allocator.directBuffer(size, size)
- new NettyByteBuf(buffer)
- }
- }
-}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java
index 722d7d62fa4..f8d71fae3e3 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufTest.java
@@ -18,24 +18,33 @@
import org.bson.ByteBuf;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
+import java.nio.ByteBuffer;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static com.mongodb.internal.connection.TestBufferProviders.TrackingBufferProvider;
class ByteBufTest {
- static Stream bufferProviders() {
- return Stream.of(new ByteBufSpecification.NettyBufferProvider(), new SimpleBufferProvider());
+ static Stream bufferProviders() {
+ return TestBufferProviders.bufferProviders();
}
- @ParameterizedTest
+ @DisplayName("Should write and read an int value")
+ @ParameterizedTest(name = "with {0}")
@MethodSource("bufferProviders")
- void shouldPutInt(final BufferProvider provider) {
- ByteBuf buffer = provider.getBuffer(1024);
+ void shouldPutInt(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
try {
buffer.putInt(42);
buffer.flip();
@@ -43,12 +52,14 @@ void shouldPutInt(final BufferProvider provider) {
} finally {
buffer.release();
}
+ bufferProvider.assertAllUnavailable();
}
- @ParameterizedTest
+ @DisplayName("Should write and read a long value")
+ @ParameterizedTest(name = "with {0}")
@MethodSource("bufferProviders")
- void shouldPutLong(final BufferProvider provider) {
- ByteBuf buffer = provider.getBuffer(1024);
+ void shouldPutLong(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
try {
buffer.putLong(42L);
buffer.flip();
@@ -56,12 +67,14 @@ void shouldPutLong(final BufferProvider provider) {
} finally {
buffer.release();
}
+ bufferProvider.assertAllUnavailable();
}
- @ParameterizedTest
+ @DisplayName("Should write and read a double value")
+ @ParameterizedTest(name = "with {0}")
@MethodSource("bufferProviders")
- void shouldPutDouble(final BufferProvider provider) {
- ByteBuf buffer = provider.getBuffer(1024);
+ void shouldPutDouble(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
try {
buffer.putDouble(42.0D);
buffer.flip();
@@ -69,12 +82,14 @@ void shouldPutDouble(final BufferProvider provider) {
} finally {
buffer.release();
}
+ bufferProvider.assertAllUnavailable();
}
- @ParameterizedTest
+ @DisplayName("Should write and read int values at specific indices")
+ @ParameterizedTest(name = "with {0}")
@MethodSource("bufferProviders")
- void shouldPutIntAtIndex(final BufferProvider provider) {
- ByteBuf buffer = provider.getBuffer(1024);
+ void shouldPutIntAtIndex(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
try {
buffer.putInt(0);
buffer.putInt(0);
@@ -97,5 +112,422 @@ void shouldPutIntAtIndex(final BufferProvider provider) {
} finally {
buffer.release();
}
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should write and read a single byte")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldPutAByte(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 42);
+ buffer.flip();
+ assertEquals(42, buffer.get());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should write and read multiple bytes in sequence")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldPutSeveralBytes(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 42);
+ buffer.put((byte) 43);
+ buffer.put((byte) 44);
+ buffer.flip();
+
+ assertEquals(42, buffer.get());
+ assertEquals(43, buffer.get());
+ assertEquals(44, buffer.get());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should write and read bytes at specific indices")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldPutBytesAtIndex(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 0);
+ buffer.put((byte) 0);
+ buffer.put((byte) 0);
+ buffer.put((byte) 0);
+ buffer.put((byte) 43);
+ buffer.put((byte) 44);
+ buffer.put(0, (byte) 22);
+ buffer.put(1, (byte) 23);
+ buffer.put(2, (byte) 24);
+ buffer.put(3, (byte) 25);
+ buffer.flip();
+
+ assertEquals(22, buffer.get());
+ assertEquals(23, buffer.get());
+ assertEquals(24, buffer.get());
+ assertEquals(25, buffer.get());
+ assertEquals(43, buffer.get());
+ assertEquals(44, buffer.get());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Remaining should decrease as bytes are written")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void whenWritingRemainingIsTheNumberOfBytesThatCanBeWritten(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ assertEquals(1024, buffer.remaining());
+ buffer.put((byte) 1);
+ assertEquals(1023, buffer.remaining());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("HasRemaining should be true while space is available and false when full")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void whenWritingHasRemainingShouldBeTrueIfThereIsStillRoomToWrite(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(2);
+ try {
+ assertTrue(buffer.hasRemaining());
+ buffer.put((byte) 1);
+ assertTrue(buffer.hasRemaining());
+ buffer.put((byte) 1);
+ assertFalse(buffer.hasRemaining());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("NIO buffer conversion should preserve capacity and limit")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldReturnNIOBufferWithTheSameCapacityAndLimit(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(36);
+ try {
+ ByteBuffer nioBuffer = buffer.asNIO();
+ assertEquals(36, nioBuffer.limit());
+ assertEquals(0, nioBuffer.position());
+ assertEquals(36, nioBuffer.remaining());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("NIO buffer conversion should preserve contents")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldReturnNIOBufferWithTheSameContents(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 42);
+ buffer.put((byte) 43);
+ buffer.put((byte) 44);
+ buffer.put((byte) 45);
+ buffer.put((byte) 46);
+ buffer.put((byte) 47);
+ buffer.flip();
+
+ ByteBuffer nioBuffer = buffer.asNIO();
+ assertEquals(6, nioBuffer.limit());
+ assertEquals(0, nioBuffer.position());
+ assertEquals(42, nioBuffer.get());
+ assertEquals(43, nioBuffer.get());
+ assertEquals(44, nioBuffer.get());
+ assertEquals(45, nioBuffer.get());
+ assertEquals(46, nioBuffer.get());
+ assertEquals(47, nioBuffer.get());
+ assertEquals(0, nioBuffer.remaining());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Reference counting should increment on retain and decrement on release")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldEnforceReferenceCounts(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ buffer.put((byte) 1);
+ assertEquals(1, buffer.getReferenceCount());
+
+ buffer.retain();
+ buffer.put((byte) 1);
+ assertEquals(2, buffer.getReferenceCount());
+
+ buffer.release();
+ buffer.put((byte) 1);
+ assertEquals(1, buffer.getReferenceCount());
+
+ buffer.release();
+ assertEquals(0, buffer.getReferenceCount());
+
+ Assertions.assertThrows(Exception.class, () -> buffer.put((byte) 1));
+
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Position should track current write position")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldTrackPosition(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ assertEquals(0, buffer.position());
+ buffer.putInt(42);
+ assertEquals(4, buffer.position());
+ buffer.putLong(100L);
+ assertEquals(12, buffer.position());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Clear should reset position and limit")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldClearPositionAndLimit(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 1);
+ buffer.put((byte) 2);
+ buffer.put((byte) 3);
+ assertEquals(3, buffer.position());
+
+ buffer.clear();
+ assertEquals(0, buffer.position());
+ assertEquals(1024, buffer.limit());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Bulk put should write multiple bytes from array")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldPutBulkBytes(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ byte[] data = {10, 20, 30, 40, 50};
+ buffer.put(data, 0, 5);
+ buffer.flip();
+
+ for (byte b : data) {
+ assertEquals(b, buffer.get());
+ }
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Bulk get should read multiple bytes into array")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldGetBulkBytes(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ byte[] original = {10, 20, 30, 40, 50};
+ buffer.put(original, 0, 5);
+ buffer.flip();
+
+ byte[] read = new byte[5];
+ buffer.get(read);
+
+ for (int i = 0; i < 5; i++) {
+ assertEquals(original[i], read[i]);
+ }
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Multiple retain should increase reference count correctly")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldHandleMultipleRetain(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ assertEquals(1, buffer.getReferenceCount());
+
+ buffer.retain();
+ buffer.retain();
+ buffer.retain();
+ assertEquals(4, buffer.getReferenceCount());
+
+ buffer.release();
+ assertEquals(3, buffer.getReferenceCount());
+ buffer.release();
+ assertEquals(2, buffer.getReferenceCount());
+ buffer.release();
+ assertEquals(1, buffer.getReferenceCount());
+ buffer.release();
+ assertEquals(0, buffer.getReferenceCount());
+
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Position should track current offset")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldSetPositionMethod(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ assertEquals(0, buffer.position());
+ buffer.put((byte) 1);
+ assertEquals(1, buffer.position());
+ buffer.put((byte) 2);
+ assertEquals(2, buffer.position());
+ buffer.put((byte) 3);
+ assertEquals(3, buffer.position());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should handle absolute byte get at specific index")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldGetByteAtIndex(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 10);
+ buffer.put((byte) 20);
+ buffer.put((byte) 30);
+
+ assertEquals(10, buffer.get(0));
+ assertEquals(20, buffer.get(1));
+ assertEquals(30, buffer.get(2));
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should handle absolute int get at specific index")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldGetIntAtIndex(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.putInt(0, 12345);
+ buffer.putInt(4, 67890);
+
+ assertEquals(12345, buffer.getInt(0));
+ assertEquals(67890, buffer.getInt(4));
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should handle absolute long get at specific index")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldGetLongAtIndex(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ long value1 = 1234567890L;
+ long value2 = 9876543210L;
+
+ buffer.putLong(value1);
+ buffer.putLong(value2);
+ buffer.position(0);
+
+ assertEquals(value1, buffer.getLong(0));
+ assertEquals(value2, buffer.getLong(8));
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should handle absolute double get at specific index")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldGetDoubleAtIndex(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ double value1 = 123.456;
+ double value2 = 789.012;
+
+ buffer.putDouble(value1);
+ buffer.putDouble(value2);
+ buffer.position(0);
+
+ assertEquals(value1, buffer.getDouble(0));
+ assertEquals(value2, buffer.getDouble(8));
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Should handle putLong method")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldPutLongRelative(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.putLong(1234567890L);
+ buffer.putLong(9876543210L);
+ buffer.flip();
+
+ assertEquals(1234567890L, buffer.getLong());
+ assertEquals(9876543210L, buffer.getLong());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Retain should return the buffer for chaining")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldReturnBufferFromRetain(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ ByteBuf retained = buffer.retain();
+ assertEquals(2, buffer.getReferenceCount());
+ assertNotNull(retained);
+ buffer.release();
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @DisplayName("Flip should return the buffer for chaining")
+ @ParameterizedTest(name = "with {0}")
+ @MethodSource("bufferProviders")
+ void shouldReturnBufferFromFlip(final String description, final TrackingBufferProvider bufferProvider) {
+ ByteBuf buffer = bufferProvider.getBuffer(1024);
+ try {
+ buffer.put((byte) 1);
+ ByteBuf flipped = buffer.flip();
+ assertNotNull(flipped);
+ assertEquals(1, buffer.get());
+ } finally {
+ buffer.release();
+ }
+ bufferProvider.assertAllUnavailable();
}
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java
index 30a70042578..3191f909266 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java
@@ -16,6 +16,7 @@
package com.mongodb.internal.connection;
import com.mongodb.internal.connection.netty.NettyByteBuf;
+import org.bson.BsonBinaryWriter;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.junit.jupiter.api.DisplayName;
@@ -41,23 +42,56 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static com.mongodb.internal.connection.TestBufferProviders.TrackingBufferProvider;
-
+@DisplayName("CompositeByteBuf")
final class CompositeByteBufTest {
+ // Construction tests
+
@Test
+ @DisplayName("Construction: should throw IllegalArgumentException when buffers is null")
@SuppressWarnings("ConstantConditions")
- void shouldThrowIfBuffersIsNull() {
+ void constructorShouldThrowIfBuffersIsNull() {
assertThrows(IllegalArgumentException.class, () -> new CompositeByteBuf((List) null));
}
@Test
- void shouldThrowIfBuffersIsEmpty() {
+ @DisplayName("Construction: should throw IllegalArgumentException when buffers is empty")
+ void constructorShouldThrowIfBuffersIsEmpty() {
assertThrows(IllegalArgumentException.class, () -> new CompositeByteBuf(emptyList()));
}
- @DisplayName("referenceCount should be maintained")
- @ParameterizedTest
+ @Test
+ @DisplayName("Construction: should calculate capacity as sum of all buffer limits")
+ void constructorShouldCalculateCapacityAsSumOfBufferLimits() {
+ assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).capacity());
+ assertEquals(6, new CompositeByteBuf(asList(
+ new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
+ new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}))
+ )).capacity());
+ }
+
+ @Test
+ @DisplayName("Construction: should initialize position to zero")
+ void constructorShouldInitializePositionToZero() {
+ assertEquals(0, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).position());
+ }
+
+ @Test
+ @DisplayName("Construction: should initialize limit as sum of all buffer limits")
+ void constructorShouldInitializeLimitAsSumOfBufferLimits() {
+ assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).limit());
+ assertEquals(6, new CompositeByteBuf(asList(
+ new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
+ new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}))
+ )).limit());
+ }
+
+ // Reference counting tests
+
+ @DisplayName("Reference counting: should maintain reference count correctly")
+ @ParameterizedTest(name = "with {0}")
@MethodSource("getBuffers")
void referenceCountShouldBeMaintained(final List buffers) {
CompositeByteBuf buf = new CompositeByteBuf(buffers);
@@ -76,108 +110,198 @@ void referenceCountShouldBeMaintained(final List buffers) {
assertThrows(IllegalStateException.class, buf::retain);
}
- private static Stream getBuffers() {
- return Stream.of(
- Arguments.of(Named.of("ByteBufNIO",
- asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
- new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))))),
- Arguments.of(Named.of("NettyByteBuf",
- asList(new NettyByteBuf(copiedBuffer(new byte[]{1, 2, 3, 4})),
- new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))),
- Arguments.of(Named.of("Mixed NIO and NettyByteBuf",
- asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
- new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4})))))
- );
+ @ParameterizedTest(name = "with {0}")
+ @DisplayName("Reference counting: should release underlying buffers when reference count reaches zero")
+ @MethodSource("bufferProviders")
+ void releaseShouldReleaseUnderlyingBuffers(final String description, final TrackingBufferProvider bufferProvider) {
+ List buffers = asList(bufferProvider.getBuffer(1), bufferProvider.getBuffer(1));
+ CompositeByteBuf compositeByteBuf = new CompositeByteBuf(buffers);
+
+ assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() > 0));
+ bufferProvider.assertAllAvailable();
+
+ compositeByteBuf.release();
+ buffers.forEach(ByteBuf::release);
+
+ assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0));
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @ParameterizedTest(name = "with {0}")
+ @DisplayName("Reference counting: duplicate should have independent reference count from original")
+ @MethodSource("bufferProviders")
+ void duplicateShouldHaveIndependentReferenceCount(final String description, final TrackingBufferProvider bufferProvider) {
+ List buffers = asList(bufferProvider.getBuffer(1), bufferProvider.getBuffer(1));
+ CompositeByteBuf compositeBuffer = new CompositeByteBuf(buffers);
+ assertEquals(1, compositeBuffer.getReferenceCount());
+
+ ByteBuf compositeBufferDuplicate = compositeBuffer.duplicate();
+ assertEquals(1, compositeBufferDuplicate.getReferenceCount());
+ assertEquals(1, compositeBuffer.getReferenceCount());
+
+ compositeBuffer.release();
+ assertEquals(0, compositeBuffer.getReferenceCount());
+ assertEquals(1, compositeBufferDuplicate.getReferenceCount());
+
+ compositeBufferDuplicate.release();
+ assertEquals(0, compositeBuffer.getReferenceCount());
+ assertEquals(0, compositeBufferDuplicate.getReferenceCount());
+
+ bufferProvider.assertAllAvailable();
+ buffers.forEach(ByteBuf::release);
+
+ assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0));
+ bufferProvider.assertAllUnavailable();
+ }
+
+ @ParameterizedTest(name = "with {0}")
+ @DisplayName("Reference counting: should work correctly with BsonBinaryWriter")
+ @MethodSource("bufferProviders")
+ void shouldWorkCorrectlyWithBsonBinaryWriter(final String description, final TrackingBufferProvider bufferProvider) {
+ List buffers;
+
+ try (ByteBufferBsonOutput bufferBsonOutput = new ByteBufferBsonOutput(bufferProvider)) {
+ try (BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(bufferBsonOutput)) {
+ bsonBinaryWriter.writeStartDocument();
+ bsonBinaryWriter.writeName("k");
+ bsonBinaryWriter.writeInt32(42);
+ bsonBinaryWriter.writeEndDocument();
+ bsonBinaryWriter.flush();
+ }
+ buffers = bufferBsonOutput.getByteBuffers();
+ assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() > 0));
+
+ CompositeByteBuf compositeBuffer = new CompositeByteBuf(buffers);
+ assertEquals(1, compositeBuffer.getReferenceCount());
+
+ ByteBuf compositeBufferDuplicate = compositeBuffer.duplicate();
+ assertEquals(1, compositeBufferDuplicate.getReferenceCount());
+ assertEquals(1, compositeBuffer.getReferenceCount());
+
+ compositeBuffer.release();
+ assertEquals(0, compositeBuffer.getReferenceCount());
+ assertEquals(1, compositeBufferDuplicate.getReferenceCount());
+
+ compositeBufferDuplicate.release();
+ assertEquals(0, compositeBuffer.getReferenceCount());
+ assertEquals(0, compositeBufferDuplicate.getReferenceCount());
+
+ bufferProvider.assertAllAvailable();
+ buffers.forEach(ByteBuf::release);
+ }
+
+ assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0));
+ bufferProvider.assertAllUnavailable();
}
@Test
+ @DisplayName("Reference counting: should throw IllegalStateException when accessing released buffer")
+ void shouldThrowIllegalStateExceptionIfBufferIsClosed() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+ buf.release();
+
+ assertThrows(IllegalStateException.class, buf::get);
+ }
+
+ // Byte order tests
+
+ @Test
+ @DisplayName("Byte order: should throw UnsupportedOperationException for BIG_ENDIAN byte order")
void orderShouldThrowIfNotLittleEndian() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
assertThrows(UnsupportedOperationException.class, () -> buf.order(ByteOrder.BIG_ENDIAN));
}
@Test
+ @DisplayName("Byte order: should accept LITTLE_ENDIAN byte order")
void orderShouldReturnNormallyIfLittleEndian() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
assertDoesNotThrow(() -> buf.order(ByteOrder.LITTLE_ENDIAN));
}
+ // Position tests
+
@Test
- void limitShouldBeSumOfLimitsOfBuffers() {
- assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).limit());
+ @DisplayName("Position: should set position when within valid range")
+ void positionShouldBeSetIfInRange() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
- assertEquals(6, new CompositeByteBuf(asList(
- new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
- new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}))
- )).limit());
+ for (int i = 0; i <= 3; i++) {
+ buf.position(i);
+ assertEquals(i, buf.position());
+ }
}
@Test
- void capacityShouldBeTheInitialLimit() {
- assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).capacity());
- assertEquals(6, new CompositeByteBuf(asList(
- new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
- new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}))
- )).capacity());
+ @DisplayName("Position: should throw IndexOutOfBoundsException when position is negative")
+ void positionShouldThrowForNegativeValue() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
+ assertThrows(IndexOutOfBoundsException.class, () -> buf.position(-1));
}
@Test
- void positionShouldBeZero() {
- assertEquals(0, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).position());
+ @DisplayName("Position: should throw IndexOutOfBoundsException when position exceeds capacity")
+ void positionShouldThrowWhenExceedsCapacity() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
+ assertThrows(IndexOutOfBoundsException.class, () -> buf.position(4));
}
@Test
- void positionShouldBeSetIfInRange() {
+ @DisplayName("Position: should throw IndexOutOfBoundsException when position exceeds limit")
+ void positionShouldThrowWhenExceedsLimit() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
- buf.position(0);
- assertEquals(0, buf.position());
-
- buf.position(1);
- assertEquals(1, buf.position());
-
- buf.position(2);
- assertEquals(2, buf.position());
-
- buf.position(3);
- assertEquals(3, buf.position());
+ buf.limit(2);
+ assertThrows(IndexOutOfBoundsException.class, () -> buf.position(3));
}
@Test
- void positionShouldThrowIfOutOfRange() {
- CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
+ @DisplayName("Position: should update remaining bytes as position changes during reads")
+ void positionRemainingAndHasRemainingShouldUpdateAsBytesAreRead() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
- assertThrows(IndexOutOfBoundsException.class, () -> buf.position(-1));
- assertThrows(IndexOutOfBoundsException.class, () -> buf.position(4));
+ for (int i = 0; i < 4; i++) {
+ assertEquals(i, buf.position());
+ assertEquals(4 - i, buf.remaining());
+ assertTrue(buf.hasRemaining());
+ buf.get();
+ }
- buf.limit(2);
- assertThrows(IndexOutOfBoundsException.class, () -> buf.position(3));
+ assertEquals(4, buf.position());
+ assertEquals(0, buf.remaining());
+ assertFalse(buf.hasRemaining());
}
+ // Limit tests
+
@Test
+ @DisplayName("Limit: should set limit when within valid range")
void limitShouldBeSetIfInRange() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
- buf.limit(0);
- assertEquals(0, buf.limit());
- buf.limit(1);
- assertEquals(1, buf.limit());
-
- buf.limit(2);
- assertEquals(2, buf.limit());
-
- buf.limit(3);
- assertEquals(3, buf.limit());
+ for (int i = 0; i <= 3; i++) {
+ buf.limit(i);
+ assertEquals(i, buf.limit());
+ }
}
@Test
- void limitShouldThrowIfOutOfRange() {
+ @DisplayName("Limit: should throw IndexOutOfBoundsException when limit is negative")
+ void limitShouldThrowForNegativeValue() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
-
assertThrows(IndexOutOfBoundsException.class, () -> buf.limit(-1));
+ }
+
+ @Test
+ @DisplayName("Limit: should throw IndexOutOfBoundsException when limit exceeds capacity")
+ void limitShouldThrowWhenExceedsCapacity() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
assertThrows(IndexOutOfBoundsException.class, () -> buf.limit(4));
}
+ // Clear tests
+
@Test
+ @DisplayName("Clear: should reset position to zero and limit to capacity")
void clearShouldResetPositionAndLimit() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
buf.limit(2);
@@ -188,7 +312,10 @@ void clearShouldResetPositionAndLimit() {
assertEquals(3, buf.limit());
}
+ // Duplicate tests
+
@Test
+ @DisplayName("Duplicate: should copy position and limit to duplicate")
void duplicateShouldCopyAllProperties() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 1, 2, 3, 4, 1, 2}))));
buf.limit(6);
@@ -203,35 +330,32 @@ void duplicateShouldCopyAllProperties() {
assertEquals(2, buf.position());
}
+ // Get byte tests
+
@Test
- void positionRemainingAndHasRemainingShouldUpdateAsBytesAreRead() {
+ @DisplayName("Get byte: relative get should read byte and move position")
+ void relativeGetShouldReadByteAndMovePosition() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
- assertEquals(0, buf.position());
- assertEquals(4, buf.remaining());
- assertTrue(buf.hasRemaining());
- buf.get();
+ assertEquals(1, buf.get());
assertEquals(1, buf.position());
- assertEquals(3, buf.remaining());
- assertTrue(buf.hasRemaining());
-
- buf.get();
+ assertEquals(2, buf.get());
assertEquals(2, buf.position());
- assertEquals(2, buf.remaining());
- assertTrue(buf.hasRemaining());
+ }
- buf.get();
- assertEquals(3, buf.position());
- assertEquals(1, buf.remaining());
- assertTrue(buf.hasRemaining());
+ @Test
+ @DisplayName("Get byte: should throw IndexOutOfBoundsException when reading past limit")
+ void getShouldThrowWhenReadingPastLimit() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+ buf.position(4);
- buf.get();
- assertEquals(4, buf.position());
- assertEquals(0, buf.remaining());
- assertFalse(buf.hasRemaining());
+ assertThrows(IndexOutOfBoundsException.class, buf::get);
}
+ // Get int tests
+
@Test
+ @DisplayName("Get int: absolute getInt should read little-endian integer and preserve position")
void absoluteGetIntShouldReadLittleEndianIntegerAndPreservePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -243,6 +367,7 @@ void absoluteGetIntShouldReadLittleEndianIntegerAndPreservePosition() {
}
@Test
+ @DisplayName("Get int: absolute getInt should read correctly when integer is split across buffers")
void absoluteGetIntShouldReadLittleEndianIntegerWhenIntegerIsSplitAcrossBuffers() {
ByteBufNIO byteBufferOne = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}));
ByteBuf byteBufferTwo = new NettyByteBuf(wrappedBuffer(new byte[]{3, 4})).flip();
@@ -256,16 +381,30 @@ void absoluteGetIntShouldReadLittleEndianIntegerWhenIntegerIsSplitAcrossBuffers(
}
@Test
+ @DisplayName("Get int: relative getInt should read little-endian integer and move position")
void relativeGetIntShouldReadLittleEndianIntegerAndMovePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
int i = buf.getInt();
+
assertEquals(67305985, i);
assertEquals(4, buf.position());
assertEquals(0, byteBuffer.position());
}
@Test
+ @DisplayName("Get int: should throw IndexOutOfBoundsException when not enough bytes for int")
+ void getIntShouldThrowWhenNotEnoughBytesForInt() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+ buf.position(1);
+
+ assertThrows(IndexOutOfBoundsException.class, buf::getInt);
+ }
+
+ // Get long tests
+
+ @Test
+ @DisplayName("Get long: absolute getLong should read little-endian long and preserve position")
void absoluteGetLongShouldReadLittleEndianLongAndPreservePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -277,7 +416,8 @@ void absoluteGetLongShouldReadLittleEndianLongAndPreservePosition() {
}
@Test
- void absoluteGetLongShouldReadLittleEndianLongWhenDoubleIsSplitAcrossBuffers() {
+ @DisplayName("Get long: absolute getLong should read correctly when long is split across multiple buffers")
+ void absoluteGetLongShouldReadLittleEndianLongWhenSplitAcrossBuffers() {
ByteBuf byteBufferOne = new NettyByteBuf(wrappedBuffer(new byte[]{1, 2})).flip();
ByteBuf byteBufferTwo = new ByteBufNIO(ByteBuffer.wrap(new byte[]{3, 4}));
ByteBuf byteBufferThree = new NettyByteBuf(wrappedBuffer(new byte[]{5, 6})).flip();
@@ -287,11 +427,10 @@ void absoluteGetLongShouldReadLittleEndianLongWhenDoubleIsSplitAcrossBuffers() {
assertEquals(578437695752307201L, l);
assertEquals(0, buf.position());
- assertEquals(0, byteBufferOne.position());
- assertEquals(0, byteBufferTwo.position());
}
@Test
+ @DisplayName("Get long: relative getLong should read little-endian long and move position")
void relativeGetLongShouldReadLittleEndianLongAndMovePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -302,7 +441,10 @@ void relativeGetLongShouldReadLittleEndianLongAndMovePosition() {
assertEquals(0, byteBuffer.position());
}
+ // Get double tests
+
@Test
+ @DisplayName("Get double: absolute getDouble should read little-endian double and preserve position")
void absoluteGetDoubleShouldReadLittleEndianDoubleAndPreservePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -314,6 +456,7 @@ void absoluteGetDoubleShouldReadLittleEndianDoubleAndPreservePosition() {
}
@Test
+ @DisplayName("Get double: relative getDouble should read little-endian double and move position")
void relativeGetDoubleShouldReadLittleEndianDoubleAndMovePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -324,7 +467,10 @@ void relativeGetDoubleShouldReadLittleEndianDoubleAndMovePosition() {
assertEquals(0, byteBuffer.position());
}
+ // Bulk get tests
+
@Test
+ @DisplayName("Bulk get: absolute bulk get should read bytes and preserve position")
void absoluteBulkGetShouldReadBytesAndPreservePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -337,6 +483,7 @@ void absoluteBulkGetShouldReadBytesAndPreservePosition() {
}
@Test
+ @DisplayName("Bulk get: absolute bulk get should read bytes split across multiple buffers")
void absoluteBulkGetShouldReadBytesWhenSplitAcrossBuffers() {
ByteBufNIO byteBufferOne = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1}));
ByteBufNIO byteBufferTwo = new ByteBufNIO(ByteBuffer.wrap(new byte[]{2, 3}));
@@ -354,6 +501,7 @@ void absoluteBulkGetShouldReadBytesWhenSplitAcrossBuffers() {
}
@Test
+ @DisplayName("Bulk get: relative bulk get should read bytes and move position")
void relativeBulkGetShouldReadBytesAndMovePosition() {
ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -372,6 +520,16 @@ void relativeBulkGetShouldReadBytesAndMovePosition() {
}
@Test
+ @DisplayName("Bulk get: should throw IndexOutOfBoundsException when bulk get exceeds remaining")
+ void bulkGetShouldThrowWhenBulkGetExceedsRemaining() {
+ CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+ assertThrows(IndexOutOfBoundsException.class, () -> buf.get(new byte[2], 1, 5));
+ }
+
+ // asNIO tests
+
+ @Test
+ @DisplayName("asNIO: should get as NIO ByteBuffer with correct position and limit")
void shouldGetAsNIOByteBuffer() {
CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}))));
buf.position(1).limit(5);
@@ -386,6 +544,7 @@ void shouldGetAsNIOByteBuffer() {
}
@Test
+ @DisplayName("asNIO: should consolidate multiple buffers into single NIO ByteBuffer")
void shouldGetAsNIOByteBufferWithMultipleBuffers() {
CompositeByteBuf buf = new CompositeByteBuf(asList(
new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})),
@@ -403,25 +562,23 @@ void shouldGetAsNIOByteBufferWithMultipleBuffers() {
assertArrayEquals(new byte[]{2, 3, 4, 5, 6}, bytes);
}
- @Test
- void shouldThrowIndexOutOfBoundsExceptionIfReadingOutOfBounds() {
- CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
- buf.position(4);
+ // Test data providers
- assertThrows(IndexOutOfBoundsException.class, buf::get);
- buf.position(1);
-
- assertThrows(IndexOutOfBoundsException.class, buf::getInt);
- buf.position(0);
-
- assertThrows(IndexOutOfBoundsException.class, () -> buf.get(new byte[2], 1, 2));
+ static Stream getBuffers() {
+ return Stream.of(
+ Arguments.of(Named.of("ByteBufNIO",
+ asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
+ new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))))),
+ Arguments.of(Named.of("NettyByteBuf",
+ asList(new NettyByteBuf(copiedBuffer(new byte[]{1, 2, 3, 4})),
+ new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))),
+ Arguments.of(Named.of("Mixed NIO and NettyByteBuf",
+ asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
+ new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4})))))
+ );
}
- @Test
- void shouldThrowIllegalStateExceptionIfBufferIsClosed() {
- CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
- buf.release();
-
- assertThrows(IllegalStateException.class, buf::get);
+ static Stream bufferProviders() {
+ return TestBufferProviders.bufferProviders();
}
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java b/driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java
new file mode 100644
index 00000000000..26da7220ce5
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/TestBufferProviders.java
@@ -0,0 +1,407 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.connection;
+
+import com.mongodb.internal.connection.netty.NettyByteBuf;
+import com.mongodb.lang.NonNull;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import org.bson.ByteBuf;
+import org.bson.ByteBufNIO;
+import org.junit.jupiter.params.provider.Arguments;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Shared buffer providers for testing ByteBuf implementations.
+ */
+public final class TestBufferProviders {
+ private TestBufferProviders() {
+ }
+
+
+ public static java.util.stream.Stream bufferProviders() {
+ TestBufferProviders.TrackingBufferProvider nioBufferProvider = TestBufferProviders.trackingBufferProvider(size -> new ByteBufNIO(ByteBuffer.allocate(size)));
+ PowerOfTwoBufferPool bufferPool = new PowerOfTwoBufferPool(1);
+ bufferPool.disablePruning();
+ TestBufferProviders.TrackingBufferProvider pooledNioBufferProvider = TestBufferProviders.trackingBufferProvider(bufferPool);
+ TestBufferProviders.TrackingBufferProvider nettyBufferProvider = TestBufferProviders.trackingNettyBufferProvider();
+ return Stream.of(
+ Arguments.of("NIO", nioBufferProvider),
+ Arguments.of("pooled NIO", pooledNioBufferProvider),
+ Arguments.of("Netty", nettyBufferProvider));
+ }
+
+ /**
+ * Creates a NettyBufferProvider that validates cleanup.
+ */
+ public static BufferProvider nettyValidatingBufferProvider() {
+ return new NettyBufferProvider();
+ }
+
+ /**
+ * Creates a TrackingBufferProvider that tracks allocated buffers for validation.
+ */
+ public static TrackingBufferProvider trackingNettyBufferProvider() {
+ return new TrackingBufferProvider(size -> new NettyByteBuf(UnpooledByteBufAllocator.DEFAULT.buffer(size, size)));
+ }
+
+ /**
+ * Creates a TrackingBufferProvider that wraps any BufferProvider.
+ */
+ public static TrackingBufferProvider trackingBufferProvider(final BufferProvider provider) {
+ return new TrackingBufferProvider(provider);
+ }
+
+ /**
+ * Creates a TrackingBufferProvider that wraps a functional provider.
+ */
+ public static TrackingBufferProvider trackingBufferProvider(final FunctionalBufferProvider provider) {
+ return new TrackingBufferProvider(provider);
+ }
+
+ /**
+ * Functional interface for creating buffers of a given size.
+ */
+ public interface FunctionalBufferProvider extends BufferProvider {
+ @Override
+ @NonNull
+ ByteBuf getBuffer(int size);
+ }
+
+ /**
+ * A NettyBufferProvider that validates cleanup and prevents use-after-free.
+ */
+ private static final class NettyBufferProvider implements BufferProvider {
+ private final ByteBufAllocator allocator;
+
+ NettyBufferProvider() {
+ allocator = PooledByteBufAllocator.DEFAULT;
+ }
+
+ @Override
+ public ByteBuf getBuffer(final int size) {
+ io.netty.buffer.ByteBuf nettyBuffer = allocator.directBuffer(size, size);
+ return new ValidatingNettyByteBuf(new NettyByteBuf(nettyBuffer));
+ }
+
+ private static final class ValidatingNettyByteBuf implements ByteBuf {
+ private final NettyByteBuf delegate;
+ private boolean released = false;
+
+ ValidatingNettyByteBuf(final NettyByteBuf delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public int capacity() {
+ validateNotReleased();
+ return delegate.capacity();
+ }
+
+ @Override
+ public ByteBuf put(final int index, final byte value) {
+ validateNotReleased();
+ return delegate.put(index, value);
+ }
+
+ @Override
+ public int remaining() {
+ validateNotReleased();
+ return delegate.remaining();
+ }
+
+ @Override
+ public ByteBuf put(final byte[] src, final int offset, final int length) {
+ validateNotReleased();
+ return delegate.put(src, offset, length);
+ }
+
+ @Override
+ public boolean hasRemaining() {
+ validateNotReleased();
+ return delegate.hasRemaining();
+ }
+
+ @Override
+ public ByteBuf put(final byte value) {
+ validateNotReleased();
+ return delegate.put(value);
+ }
+
+ @Override
+ public ByteBuf putInt(final int value) {
+ validateNotReleased();
+ return delegate.putInt(value);
+ }
+
+ @Override
+ public ByteBuf putInt(final int index, final int value) {
+ validateNotReleased();
+ return delegate.putInt(index, value);
+ }
+
+ @Override
+ public ByteBuf putDouble(final double value) {
+ validateNotReleased();
+ return delegate.putDouble(value);
+ }
+
+ @Override
+ public ByteBuf putLong(final long value) {
+ validateNotReleased();
+ return delegate.putLong(value);
+ }
+
+ @Override
+ public ByteBuf flip() {
+ validateNotReleased();
+ return delegate.flip();
+ }
+
+ @Override
+ public byte[] array() {
+ validateNotReleased();
+ return delegate.array();
+ }
+
+ @Override
+ public boolean isBackedByArray() {
+ validateNotReleased();
+ return delegate.isBackedByArray();
+ }
+
+ @Override
+ public int arrayOffset() {
+ validateNotReleased();
+ return delegate.arrayOffset();
+ }
+
+ @Override
+ public int limit() {
+ validateNotReleased();
+ return delegate.limit();
+ }
+
+ @Override
+ public ByteBuf position(final int newPosition) {
+ validateNotReleased();
+ return delegate.position(newPosition);
+ }
+
+ @Override
+ public ByteBuf clear() {
+ validateNotReleased();
+ return delegate.clear();
+ }
+
+ @Override
+ public ByteBuf order(final java.nio.ByteOrder byteOrder) {
+ validateNotReleased();
+ return delegate.order(byteOrder);
+ }
+
+ @Override
+ public byte get() {
+ validateNotReleased();
+ return delegate.get();
+ }
+
+ @Override
+ public byte get(final int index) {
+ validateNotReleased();
+ return delegate.get(index);
+ }
+
+ @Override
+ public ByteBuf get(final byte[] bytes) {
+ validateNotReleased();
+ return delegate.get(bytes);
+ }
+
+ @Override
+ public ByteBuf get(final int index, final byte[] bytes) {
+ validateNotReleased();
+ return delegate.get(index, bytes);
+ }
+
+ @Override
+ public ByteBuf get(final byte[] bytes, final int offset, final int length) {
+ validateNotReleased();
+ return delegate.get(bytes, offset, length);
+ }
+
+ @Override
+ public ByteBuf get(final int index, final byte[] bytes, final int offset, final int length) {
+ validateNotReleased();
+ return delegate.get(index, bytes, offset, length);
+ }
+
+ @Override
+ public long getLong() {
+ validateNotReleased();
+ return delegate.getLong();
+ }
+
+ @Override
+ public long getLong(final int index) {
+ validateNotReleased();
+ return delegate.getLong(index);
+ }
+
+ @Override
+ public double getDouble() {
+ validateNotReleased();
+ return delegate.getDouble();
+ }
+
+ @Override
+ public double getDouble(final int index) {
+ validateNotReleased();
+ return delegate.getDouble(index);
+ }
+
+ @Override
+ public int getInt() {
+ validateNotReleased();
+ return delegate.getInt();
+ }
+
+ @Override
+ public int getInt(final int index) {
+ validateNotReleased();
+ return delegate.getInt(index);
+ }
+
+ @Override
+ public int position() {
+ validateNotReleased();
+ return delegate.position();
+ }
+
+ @Override
+ public ByteBuf limit(final int newLimit) {
+ validateNotReleased();
+ return delegate.limit(newLimit);
+ }
+
+ @Override
+ public ByteBuf asReadOnly() {
+ validateNotReleased();
+ return delegate.asReadOnly();
+ }
+
+ @Override
+ public ByteBuf duplicate() {
+ validateNotReleased();
+ return delegate.duplicate();
+ }
+
+ @Override
+ public java.nio.ByteBuffer asNIO() {
+ validateNotReleased();
+ return delegate.asNIO();
+ }
+
+ @Override
+ public int getReferenceCount() {
+ validateNotReleased();
+ return delegate.getReferenceCount();
+ }
+
+ @Override
+ public ByteBuf retain() {
+ validateNotReleased();
+ return delegate.retain();
+ }
+
+ @Override
+ public void release() {
+ if (!released) {
+ released = true;
+ delegate.release();
+ assertEquals(0, delegate.getReferenceCount(), "Buffer should have reference count 0 after release");
+ }
+ }
+
+ private void validateNotReleased() {
+ if (released) {
+ throw new IllegalStateException("Buffer has been released");
+ }
+ }
+ }
+ }
+
+ /**
+ * A BufferProvider that tracks allocated buffers for validation.
+ */
+ public static final class TrackingBufferProvider implements BufferProvider {
+ private final BufferProvider decorated;
+ private final List tracked;
+
+ public TrackingBufferProvider(final BufferProvider decorated) {
+ this.decorated = decorated;
+ tracked = new ArrayList<>();
+ }
+
+ @NonNull
+ @Override
+ public ByteBuf getBuffer(final int size) {
+ ByteBuf result = decorated.getBuffer(size);
+ tracked.add(result);
+ return result;
+ }
+
+ /**
+ * Asserts that all tracked buffers are still available (reference count > 0).
+ */
+ public void assertAllAvailable() {
+ for (ByteBuf buffer : tracked) {
+ assertTrue(buffer.getReferenceCount() > 0);
+ if (buffer instanceof ByteBufNIO) {
+ assertNotNull(buffer.asNIO());
+ } else if (buffer instanceof NettyByteBuf) {
+ assertTrue(((NettyByteBuf) buffer).asByteBuf().refCnt() > 0);
+ }
+ }
+ }
+
+ /**
+ * Asserts that all tracked buffers have been released (reference count = 0).
+ */
+ public void assertAllUnavailable() {
+ for (ByteBuf buffer : tracked) {
+ assertEquals(0, buffer.getReferenceCount());
+ if (buffer instanceof ByteBufNIO) {
+ assertNull(buffer.asNIO());
+ }
+ if (buffer instanceof NettyByteBuf) {
+ assertEquals(0, ((NettyByteBuf) buffer).asByteBuf().refCnt());
+ }
+ }
+ }
+ }
+}