From fc908e2dcc85e58bfd281d44b168da352aefd813 Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Mon, 8 Oct 2018 14:14:09 -0700 Subject: netty: expose setting a local socket address --- .../java/io/grpc/netty/NettyChannelBuilder.java | 48 ++++++++++++++++++++-- .../java/io/grpc/netty/NettyClientTransport.java | 19 ++++++--- .../io/grpc/netty/NettyClientTransportTest.java | 17 ++++++-- 3 files changed, 73 insertions(+), 11 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 4989694e3..46beec533 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -26,6 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.Internal; import io.grpc.NameResolver; @@ -80,6 +81,7 @@ public final class NettyChannelBuilder private long keepAliveTimeoutNanos = DEFAULT_KEEPALIVE_TIMEOUT_NANOS; private boolean keepAliveWithoutCalls; private ProtocolNegotiatorFactory protocolNegotiatorFactory; + private LocalSocketPicker localSocketPicker; /** * Creates a new builder with the given server address. This factory method is primarily intended @@ -326,6 +328,41 @@ public final class NettyChannelBuilder return this; } + + /** + * If non-{@code null}, attempts to create connections bound to a local port. + */ + public NettyChannelBuilder localSocketPicker(@Nullable LocalSocketPicker localSocketPicker) { + this.localSocketPicker = localSocketPicker; + return this; + } + + /** + * This class is meant to be overriden with a custom implementation of + * {@link #createSocketAddress}. The default implementation is a no-op. + * + * @since 1.16.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4917") + public static class LocalSocketPicker { + + /** + * Called by gRPC to pick local socket to bind to. This may be called multiple times. + * Subclasses are expected to override this method. + * + * @param remoteAddress the remote address to connect to. + * @param attrs the Attributes present on the {@link io.grpc.EquivalentAddressGroup} associated + * with the address. + * @return a {@link SocketAddress} suitable for binding, or else {@code null}. + * @since 1.16.0 + */ + @Nullable + public SocketAddress createSocketAddress( + SocketAddress remoteAddress, @EquivalentAddressGroup.Attr Attributes attrs) { + return null; + } + } + @Override @CheckReturnValue @Internal @@ -348,7 +385,7 @@ public final class NettyChannelBuilder negotiator, channelType, channelOptions, eventLoopGroup, flowControlWindow, maxInboundMessageSize(), maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, - transportTracerFactory.create()); + transportTracerFactory.create(), localSocketPicker); } @Override @@ -457,6 +494,7 @@ public final class NettyChannelBuilder private final long keepAliveTimeoutNanos; private final boolean keepAliveWithoutCalls; private final TransportTracer transportTracer; + private final LocalSocketPicker localSocketPicker; private boolean closed; @@ -464,7 +502,7 @@ public final class NettyChannelBuilder Class channelType, Map, ?> channelOptions, EventLoopGroup group, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, - TransportTracer transportTracer) { + TransportTracer transportTracer, LocalSocketPicker localSocketPicker) { this.protocolNegotiator = protocolNegotiator; this.channelType = channelType; this.channelOptions = new HashMap, Object>(channelOptions); @@ -475,6 +513,8 @@ public final class NettyChannelBuilder this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.transportTracer = transportTracer; + this.localSocketPicker = + localSocketPicker != null ? localSocketPicker : new LocalSocketPicker(); usingSharedGroup = group == null; if (usingSharedGroup) { @@ -505,12 +545,14 @@ public final class NettyChannelBuilder keepAliveTimeNanosState.backoff(); } }; + NettyClientTransport transport = new NettyClientTransport( serverAddress, channelType, channelOptions, group, localNegotiator, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), - tooManyPingsRunnable, transportTracer, options.getEagAttributes()); + tooManyPingsRunnable, transportTracer, options.getEagAttributes(), + localSocketPicker); return transport; } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index 6141db2fe..2be71c20d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -40,6 +40,7 @@ import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager.ClientKeepAlivePinger; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; +import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -64,7 +65,7 @@ import javax.annotation.Nullable; class NettyClientTransport implements ConnectionClientTransport { private final InternalLogId logId = InternalLogId.allocate(getClass().getName()); private final Map, ?> channelOptions; - private final SocketAddress address; + private final SocketAddress remoteAddress; private final Class channelType; private final EventLoopGroup group; private final ProtocolNegotiator negotiator; @@ -91,6 +92,7 @@ class NettyClientTransport implements ConnectionClientTransport { /** Since not thread-safe, may only be used from event loop. */ private final TransportTracer transportTracer; private final Attributes eagAttributes; + private final LocalSocketPicker localSocketPicker; NettyClientTransport( SocketAddress address, Class channelType, @@ -98,9 +100,10 @@ class NettyClientTransport implements ConnectionClientTransport { ProtocolNegotiator negotiator, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, String authority, @Nullable String userAgent, - Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes) { + Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, + LocalSocketPicker localSocketPicker) { this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); - this.address = Preconditions.checkNotNull(address, "address"); + this.remoteAddress = Preconditions.checkNotNull(address, "address"); this.group = Preconditions.checkNotNull(group, "group"); this.channelType = Preconditions.checkNotNull(channelType, "channelType"); this.channelOptions = Preconditions.checkNotNull(channelOptions, "channelOptions"); @@ -117,6 +120,7 @@ class NettyClientTransport implements ConnectionClientTransport { Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes"); + this.localSocketPicker = Preconditions.checkNotNull(localSocketPicker, "localSocketPicker"); } @Override @@ -215,6 +219,11 @@ class NettyClientTransport implements ConnectionClientTransport { // so it is safe to pass the key-value pair to b.option(). b.option((ChannelOption) entry.getKey(), entry.getValue()); } + SocketAddress localAddress = + localSocketPicker.createSocketAddress(remoteAddress, eagAttributes); + if (localAddress != null) { + b.localAddress(localAddress); + } /** * We don't use a ChannelInitializer in the client bootstrap because its "initChannel" method @@ -263,7 +272,7 @@ class NettyClientTransport implements ConnectionClientTransport { } }); // Start the connection operation to the server. - channel.connect(address); + channel.connect(remoteAddress); if (keepAliveManager != null) { keepAliveManager.onTransportStarted(); @@ -305,7 +314,7 @@ class NettyClientTransport implements ConnectionClientTransport { public String toString() { return MoreObjects.toStringHelper(this) .add("logId", logId.getId()) - .add("address", address) + .add("remoteAddress", remoteAddress) .add("channel", channel) .toString(); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index bd3c1e3a2..7b28fac5c 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -61,6 +61,7 @@ import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; import io.grpc.internal.TransportTracer; import io.grpc.internal.testing.TestUtils; +import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelOption; import io.netty.channel.nio.NioEventLoopGroup; @@ -87,6 +88,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import javax.annotation.Nullable; import javax.net.ssl.SSLHandshakeException; import org.junit.After; import org.junit.Before; @@ -179,7 +181,7 @@ public class NettyClientTransportTest { address, NioSocketChannel.class, channelOptions, group, newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, null /* user agent */, - tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY); + tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker()); transports.add(transport); callMeMaybe(transport.start(clientTransportListener)); @@ -419,7 +421,7 @@ public class NettyClientTransportTest { address, CantConstructChannel.class, new HashMap, Object>(), group, newNegotiator(), DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, - null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY); + null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker()); transports.add(transport); // Should not throw @@ -602,7 +604,7 @@ public class NettyClientTransportTest { DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, keepAliveTimeNano, keepAliveTimeoutNano, false, authority, userAgent, tooManyPingsRunnable, - new TransportTracer(), eagAttributes); + new TransportTracer(), eagAttributes, new SocketPicker()); transports.add(transport); return transport; } @@ -835,4 +837,13 @@ public class NettyClientTransportTest { @Override public void close() {} } + + private static final class SocketPicker extends LocalSocketPicker { + + @Nullable + @Override + public SocketAddress createSocketAddress(SocketAddress remoteAddress, Attributes attrs) { + return null; + } + } } -- cgit v1.2.3