Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ See [`.agents/references/style-reference`](.agents/references/style-reference.md

- No `System.out.println` / `System.err.println` — use SLF4J
- No `e.printStackTrace()` — use proper error handling
- Prefer lambdas over SAM (Single Abstract Method) anonymous class instantiation
- Copyright header required: `Copyright 2008-present MongoDB, Inc.`
- Every public package must have a `package-info.java`

Expand Down
4 changes: 4 additions & 0 deletions driver-core/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ Largest and most complex module.
./gradlew :driver-core:generateMongoDriverVersion # If MongoDriverVersion is missing
```

## Important

- Async code MUST handle errors and they MUST be handled via callbacks or handlers.

## Notes

- Most extensive test suite — JUnit 5 + Spock + Mockito.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.mongodb.internal.connection;

import com.mongodb.MongoClientException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
Expand All @@ -37,6 +38,7 @@
import javax.net.ssl.SSLParameters;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
Expand All @@ -46,6 +48,7 @@
import java.nio.channels.SocketChannel;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -209,35 +212,64 @@ private static class TlsChannelStream extends AsynchronousChannelStream {
@Override
public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler<Void> handler) {
isTrue("unopened", getChannel() == null);
SocketChannel socketChannel = null;
SelectorMonitor.SocketRegistration socketRegistration = null;
try {
SocketChannel socketChannel = SocketChannel.open();
socketChannel.configureBlocking(false);
// getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeoutException.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
List<InetSocketAddress> socketAddresses = getSocketAddresses(getServerAddress(), inetAddressResolver);
if (socketAddresses.isEmpty()) {
throw new MongoSocketException("No addresses resolved for " + getServerAddress(), getServerAddress());
}
InetSocketAddress socketAddress = socketAddresses.get(0);
SocketChannel openedSocketChannel = SocketChannel.open();
Comment thread
rozza marked this conversation as resolved.
socketChannel = openedSocketChannel;
openedSocketChannel.configureBlocking(false);

socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
if (getSettings().getReceiveBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
}
if (getSettings().getSendBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
}
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
socketChannel, () -> initializeTslChannel(handler, socketChannel));
openedSocketChannel.connect(socketAddress);
socketRegistration = new SelectorMonitor.SocketRegistration(
openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel));

if (connectTimeoutMs > 0) {
scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs);
}
selectorMonitor.register(socketRegistration);
} catch (IOException e) {
closeSocketChannel(socketChannel, socketRegistration, e);
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
closeSocketChannel(socketChannel, socketRegistration, t);
handler.failed(t);
}
}

private void closeSocketChannel(@Nullable final SocketChannel socketChannel,
@Nullable final SelectorMonitor.SocketRegistration socketRegistration,
final Throwable failure) {
if (socketRegistration != null) {
try {
socketRegistration.tryCancelPendingConnection();
} catch (Throwable t) {
failure.addSuppressed(t);
}
}
if (socketChannel != null) {
try {
socketChannel.close();
} catch (Throwable e) {
failure.addSuppressed(e);
}
}
}

private void scheduleTimeoutInterruption(final AsyncCompletionHandler<Void> handler,
final SelectorMonitor.SocketRegistration socketRegistration,
final int connectTimeoutMs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@
package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import com.mongodb.spi.dns.InetAddressResolver;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
Expand All @@ -37,6 +41,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.InterruptedByTimeoutException;
Expand All @@ -52,13 +57,17 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -68,6 +77,64 @@ class TlsChannelStreamFunctionalTest {
private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1";
private static final int UNREACHABLE_PORT = 65333;

@Test
void shouldNotOpenSocketChannelIfNameResolutionFails() {
//given
MongoSocketException resolverException = new MongoSocketException("Temporary failure in name resolution", new ServerAddress());
InetAddressResolver inetAddressResolver = host -> {
throw resolverException;
};

try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.build(), SSL_SETTINGS);
Stream stream = streamFactory.create(new ServerAddress());
@SuppressWarnings("unchecked")
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);

//when
stream.openAsync(createOperationContext(100), handler);

//then
verify(handler).failed(resolverException);
verify(handler, never()).completed(null);
socketChannelMockedStatic.verify(SocketChannel::open, never());
}
}

@Test
void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException {
//given
IOException connectException = new IOException("connect failed");
InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress());

try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open());
StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel);
Mockito.doThrow(connectException).when(socketChannel).connect(any());
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.build(), SSL_SETTINGS);
Stream stream = streamFactory.create(new ServerAddress());
@SuppressWarnings("unchecked")
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);
ArgumentCaptor<Throwable> failureCaptor = ArgumentCaptor.forClass(Throwable.class);

//when
stream.openAsync(createOperationContext(100), handler);

//then
verify(handler).failed(failureCaptor.capture());
MongoSocketOpenException actual = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue());
assertSame(connectException, actual.getCause());
verify(handler, never()).completed(null);
verify(socketChannel, atLeastOnce()).close();
}
Comment thread
rozza marked this conversation as resolved.
}

@ParameterizedTest
@ValueSource(ints = {500, 1000, 2000})
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {
Expand Down