diff options
author | sanjaypujare <sanjaypujare@users.noreply.github.com> | 2023-11-03 09:57:59 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-03 09:57:59 -0700 |
commit | 15fc70be2ab92271f0a74c9cd04a73eba66193aa (patch) | |
tree | a0ace47d2a4ecab963b67a3f225efa8e7f9bc245 | |
parent | 9888a54abd88d60a367cdd476be453591b9240e1 (diff) | |
download | grpc-grpc-java-15fc70be2ab92271f0a74c9cd04a73eba66193aa.tar.gz |
core, netty, okhttp: implement new logic for nameResolverFactory API in channelBuilder (#10590)
* core, netty, okhttp: implement new logic for nameResolverFactory API in channelBuilder
fix ManagedChannelImpl to use NameResolverRegistry instead of NameResolverFactory
fix the ManagedChannelImplBuilder and remove nameResolverFactory
* Integrate target parsing and NameResolverProvider searching
Actually creating the name resolver is now delayed to the end of
ManagedChannelImpl.getNameResolver; we don't want to call into the name
resolver to determine if we should use the name resolver.
Added getDefaultScheme() to NameResolverRegistry to avoid needing
NameResolver.Factory.
---------
Co-authored-by: Eric Anderson <ejona@google.com>
42 files changed, 796 insertions, 118 deletions
diff --git a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java index 169afe307..8e8d175b7 100644 --- a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java +++ b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java @@ -27,6 +27,7 @@ import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.concurrent.DefaultThreadFactory; +import java.net.InetSocketAddress; import java.util.concurrent.TimeUnit; /** @@ -57,7 +58,7 @@ final class HandshakerServiceChannel { EventLoopGroup eventGroup = new NioEventLoopGroup(1, new DefaultThreadFactory("handshaker pool", true)); ManagedChannel channel = NettyChannelBuilder.forTarget(target) - .channelType(NioSocketChannel.class) + .channelType(NioSocketChannel.class, InetSocketAddress.class) .directExecutor() .eventLoopGroup(eventGroup) .usePlaintext() diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index 04bdc6b0d..31f874b80 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -161,13 +161,13 @@ public final class ManagedChannelRegistry { NameResolverProvider nameResolverProvider = null; try { URI uri = new URI(target); - nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme()); + nameResolverProvider = nameResolverRegistry.getProviderForScheme(uri.getScheme()); } catch (URISyntaxException ignore) { // bad URI found, just ignore and continue } if (nameResolverProvider == null) { - nameResolverProvider = nameResolverRegistry.providers().get( - nameResolverRegistry.asFactory().getDefaultScheme()); + nameResolverProvider = nameResolverRegistry.getProviderForScheme( + nameResolverRegistry.getDefaultScheme()); } Collection<Class<? extends SocketAddress>> nameResolverSocketAddressTypes = (nameResolverProvider != null) diff --git a/api/src/main/java/io/grpc/NameResolverProvider.java b/api/src/main/java/io/grpc/NameResolverProvider.java index 13cd750c3..70e22e366 100644 --- a/api/src/main/java/io/grpc/NameResolverProvider.java +++ b/api/src/main/java/io/grpc/NameResolverProvider.java @@ -75,7 +75,7 @@ public abstract class NameResolverProvider extends NameResolver.Factory { * * @return the {@link SocketAddress} types this provider's name-resolver is capable of producing. */ - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } } diff --git a/api/src/main/java/io/grpc/NameResolverRegistry.java b/api/src/main/java/io/grpc/NameResolverRegistry.java index 37dd92832..23eec23fd 100644 --- a/api/src/main/java/io/grpc/NameResolverRegistry.java +++ b/api/src/main/java/io/grpc/NameResolverRegistry.java @@ -58,6 +58,16 @@ public final class NameResolverRegistry { @GuardedBy("this") private ImmutableMap<String, NameResolverProvider> effectiveProviders = ImmutableMap.of(); + public synchronized String getDefaultScheme() { + return defaultScheme; + } + + public NameResolverProvider getProviderForScheme(String scheme) { + if (scheme == null) { + return null; + } + return providers().get(scheme.toLowerCase(Locale.US)); + } /** * Register a provider. @@ -163,19 +173,13 @@ public final class NameResolverRegistry { @Override @Nullable public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { - String scheme = targetUri.getScheme(); - if (scheme == null) { - return null; - } - NameResolverProvider provider = providers().get(scheme.toLowerCase(Locale.US)); + NameResolverProvider provider = getProviderForScheme(targetUri.getScheme()); return provider == null ? null : provider.newNameResolver(targetUri, args); } @Override public String getDefaultScheme() { - synchronized (NameResolverRegistry.this) { - return defaultScheme; - } + return NameResolverRegistry.this.getDefaultScheme(); } } diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 4a6dfa49c..30de2477d 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -173,13 +173,13 @@ public class ManagedChannelRegistryTest { nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(SocketAddress1.class); } }); nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") { @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { fail("Should not be called"); throw new AssertionError(); } @@ -234,7 +234,7 @@ public class ManagedChannelRegistryTest { nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class); } }); @@ -314,7 +314,7 @@ public class ManagedChannelRegistryTest { nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(SocketAddress1.class); } }); diff --git a/api/src/test/java/io/grpc/NameResolverRegistryTest.java b/api/src/test/java/io/grpc/NameResolverRegistryTest.java index 19ae09520..32067e976 100644 --- a/api/src/test/java/io/grpc/NameResolverRegistryTest.java +++ b/api/src/test/java/io/grpc/NameResolverRegistryTest.java @@ -203,12 +203,14 @@ public class NameResolverRegistryTest { public void baseProviders() { Map<String, NameResolverProvider> providers = NameResolverRegistry.getDefaultRegistry().providers(); - assertThat(providers).hasSize(1); + assertThat(providers).hasSize(2); // 2 name resolvers from grpclb and core, higher priority one is returned. assertThat(providers.get("dns").getClass().getName()) .isEqualTo("io.grpc.grpclb.SecretGrpclbNameResolverProvider$Provider"); assertThat(NameResolverRegistry.getDefaultRegistry().asFactory().getDefaultScheme()) .isEqualTo("dns"); + assertThat(providers.get("inprocess").getClass().getName()) + .isEqualTo("io.grpc.inprocess.InProcessNameResolverProvider"); } @Test diff --git a/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java b/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java index ea23fadee..d0de1571a 100644 --- a/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java +++ b/benchmarks/src/jmh/java/io/grpc/benchmarks/TransportBenchmark.java @@ -110,7 +110,7 @@ public class TransportBenchmark { .channelType(LocalServerChannel.class); channelBuilder = NettyChannelBuilder.forAddress(address) .eventLoopGroup(group) - .channelType(LocalChannel.class) + .channelType(LocalChannel.class, LocalAddress.class) .negotiationType(NegotiationType.PLAINTEXT); groupToShutdown = group; break; @@ -134,7 +134,7 @@ public class TransportBenchmark { .asSubclass(Channel.class); channelBuilder = NettyChannelBuilder.forAddress(address) .eventLoopGroup(group) - .channelType(channelClass) + .channelType(channelClass, InetSocketAddress.class) .negotiationType(NegotiationType.PLAINTEXT); groupToShutdown = group; break; diff --git a/benchmarks/src/jmh/java/io/grpc/benchmarks/netty/AbstractBenchmark.java b/benchmarks/src/jmh/java/io/grpc/benchmarks/netty/AbstractBenchmark.java index d68e66561..6d8a9ec8a 100644 --- a/benchmarks/src/jmh/java/io/grpc/benchmarks/netty/AbstractBenchmark.java +++ b/benchmarks/src/jmh/java/io/grpc/benchmarks/netty/AbstractBenchmark.java @@ -207,7 +207,7 @@ public abstract class AbstractBenchmark { serverBuilder = NettyServerBuilder.forAddress(address, serverCreds); serverBuilder.channelType(LocalServerChannel.class); channelBuilder = NettyChannelBuilder.forAddress(address); - channelBuilder.channelType(LocalChannel.class); + channelBuilder.channelType(LocalChannel.class, LocalAddress.class); } else { ServerSocket sock = new ServerSocket(); // Pick a port using an ephemeral socket. @@ -216,7 +216,8 @@ public abstract class AbstractBenchmark { sock.close(); serverBuilder = NettyServerBuilder.forAddress(address, serverCreds) .channelType(NioServerSocketChannel.class); - channelBuilder = NettyChannelBuilder.forAddress(address).channelType(NioSocketChannel.class); + channelBuilder = NettyChannelBuilder.forAddress(address).channelType(NioSocketChannel.class, + InetSocketAddress.class); } if (serverExecutor == ExecutorType.DIRECT) { diff --git a/benchmarks/src/main/java/io/grpc/benchmarks/Utils.java b/benchmarks/src/main/java/io/grpc/benchmarks/Utils.java index 8087afbf4..c4ba99e16 100644 --- a/benchmarks/src/main/java/io/grpc/benchmarks/Utils.java +++ b/benchmarks/src/main/java/io/grpc/benchmarks/Utils.java @@ -130,21 +130,21 @@ public final class Utils { case NETTY_NIO: builder .eventLoopGroup(new NioEventLoopGroup(0, tf)) - .channelType(NioSocketChannel.class); + .channelType(NioSocketChannel.class, InetSocketAddress.class); break; case NETTY_EPOLL: // These classes only work on Linux. builder .eventLoopGroup(new EpollEventLoopGroup(0, tf)) - .channelType(EpollSocketChannel.class); + .channelType(EpollSocketChannel.class, InetSocketAddress.class); break; case NETTY_UNIX_DOMAIN_SOCKET: // These classes only work on Linux. builder .eventLoopGroup(new EpollEventLoopGroup(0, tf)) - .channelType(EpollDomainSocketChannel.class); + .channelType(EpollDomainSocketChannel.class, DomainSocketAddress.class); break; default: diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java index c096fc194..83eabf407 100644 --- a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java @@ -39,6 +39,8 @@ import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -401,5 +403,10 @@ public final class BinderChannelBuilder executorService = scheduledExecutorPool.returnObject(executorService); offloadExecutor = offloadExecutorPool.returnObject(offloadExecutor); } + + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return Collections.singleton(AndroidComponentAddress.class); + } } } diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 1537d1c66..426318519 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -34,6 +34,7 @@ import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener; import java.net.SocketAddress; +import java.util.Collection; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; @@ -74,6 +75,11 @@ final class CallCredentialsApplyingTransportFactory implements ClientTransportFa delegate.close(); } + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return delegate.getSupportedSocketAddressTypes(); + } + private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport { private final ConnectionClientTransport delegate; private final String authority; diff --git a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java index 4d2ee92a0..d987f9d50 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java @@ -25,6 +25,7 @@ import io.grpc.ChannelLogger; import io.grpc.HttpConnectProxiedSocketAddress; import java.io.Closeable; import java.net.SocketAddress; +import java.util.Collection; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; @@ -74,6 +75,11 @@ public interface ClientTransportFactory extends Closeable { void close(); /** + * Returns the {@link SocketAddress} types this transport supports. + */ + Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes(); + + /** * Options passed to {@link #newClientTransport}. Although it is safe to save this object if * received, it is generally expected that the useful fields are copied and then the options * object is discarded. This allows using {@code final} for those fields as well as avoids diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java index 414a0ae88..c977fbb0c 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java @@ -84,7 +84,7 @@ public final class DnsNameResolverProvider extends NameResolverProvider { } @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 6d92b7851..a6e5e80f3 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -73,6 +73,7 @@ import io.grpc.MethodDescriptor; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.ProxyDetector; import io.grpc.Status; @@ -88,6 +89,7 @@ import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelect import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.RetriableStream.Throttle; import io.grpc.internal.RetryingNameResolver.ResolutionResultListener; +import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; @@ -160,7 +162,6 @@ final class ManagedChannelImpl extends ManagedChannel implements @Nullable private final String authorityOverride; private final NameResolverRegistry nameResolverRegistry; - private final NameResolver.Factory nameResolverFactory; private final NameResolver.Args nameResolverArgs; private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; private final ClientTransportFactory originalTransportFactory; @@ -376,7 +377,8 @@ final class ManagedChannelImpl extends ManagedChannel implements nameResolverStarted = false; if (channelIsActive) { nameResolver = getNameResolver( - target, authorityOverride, nameResolverFactory, nameResolverArgs); + target, authorityOverride, nameResolverRegistry, nameResolverArgs, + transportFactory.getSupportedSocketAddressTypes()); } else { nameResolver = null; } @@ -630,9 +632,9 @@ final class ManagedChannelImpl extends ManagedChannel implements .setOffloadExecutor(this.offloadExecutorHolder) .setOverrideAuthority(this.authorityOverride) .build(); - this.nameResolverFactory = builder.nameResolverFactory; this.nameResolver = getNameResolver( - target, authorityOverride, nameResolverFactory, nameResolverArgs); + target, authorityOverride, nameResolverRegistry, nameResolverArgs, + transportFactory.getSupportedSocketAddressTypes()); this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool"); this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool); this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext); @@ -704,54 +706,70 @@ final class ManagedChannelImpl extends ManagedChannel implements } private static NameResolver getNameResolver( - String target, NameResolver.Factory nameResolverFactory, NameResolver.Args nameResolverArgs) { + String target, NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, + Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) { // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending // "dns:///". + NameResolverProvider provider = null; URI targetUri = null; StringBuilder uriSyntaxErrors = new StringBuilder(); try { targetUri = new URI(target); - // For "localhost:8080" this would likely cause newNameResolver to return null, because - // "localhost" is parsed as the scheme. Will fall into the next branch and try - // "dns:///localhost:8080". } catch (URISyntaxException e) { // Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. uriSyntaxErrors.append(e.getMessage()); } if (targetUri != null) { - NameResolver resolver = nameResolverFactory.newNameResolver(targetUri, nameResolverArgs); - if (resolver != null) { - return resolver; - } - // "foo.googleapis.com:8080" cause resolver to be null, because "foo.googleapis.com" is an - // unmapped scheme. Just fall through and will try "dns:///foo.googleapis.com:8080" + // For "localhost:8080" this would likely cause provider to be null, because "localhost" is + // parsed as the scheme. Will hit the next case and try "dns:///localhost:8080". + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); } - // If we reached here, the targetUri couldn't be used. - if (!URI_PATTERN.matcher(target).matches()) { + if (provider == null && !URI_PATTERN.matcher(target).matches()) { // It doesn't look like a URI target. Maybe it's an authority string. Try with the default - // scheme from the factory. + // scheme from the registry. try { - targetUri = new URI(nameResolverFactory.getDefaultScheme(), "", "/" + target, null); + targetUri = new URI(nameResolverRegistry.getDefaultScheme(), "", "/" + target, null); } catch (URISyntaxException e) { // Should not be possible. throw new IllegalArgumentException(e); } - NameResolver resolver = nameResolverFactory.newNameResolver(targetUri, nameResolverArgs); - if (resolver != null) { - return resolver; + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null) { + throw new IllegalArgumentException(String.format( + "Could not find a NameResolverProvider for %s%s", + target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); + } + + if (channelTransportSocketAddressTypes != null) { + Collection<Class<? extends SocketAddress>> nameResolverSocketAddressTypes + = provider.getProducedSocketAddressTypes(); + if (!channelTransportSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { + throw new IllegalArgumentException(String.format( + "Address types of NameResolver '%s' for '%s' not supported by transport", + targetUri.getScheme(), target)); } } + + NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs); + if (resolver != null) { + return resolver; + } + throw new IllegalArgumentException(String.format( - "cannot find a NameResolver for %s%s", + "cannot create a NameResolver for %s%s", target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); } @VisibleForTesting static NameResolver getNameResolver( String target, @Nullable final String overrideAuthority, - NameResolver.Factory nameResolverFactory, NameResolver.Args nameResolverArgs) { - NameResolver resolver = getNameResolver(target, nameResolverFactory, nameResolverArgs); + NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, + Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) { + NameResolver resolver = getNameResolver(target, nameResolverRegistry, nameResolverArgs, + channelTransportSocketAddressTypes); // We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures. // TODO: After a transition period, all NameResolver implementations that need retry should use @@ -1625,7 +1643,8 @@ final class ManagedChannelImpl extends ManagedChannel implements channelCreds, callCredentials, transportFactoryBuilder, - new FixedPortProvider(nameResolverArgs.getDefaultPort())); + new FixedPortProvider(nameResolverArgs.getDefaultPort())) + .nameResolverRegistry(nameResolverRegistry); } @Override @@ -1637,8 +1656,7 @@ final class ManagedChannelImpl extends ManagedChannel implements checkState(!terminated, "Channel is terminated"); @SuppressWarnings("deprecation") - ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder() - .nameResolverFactory(nameResolverFactory); + ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(); return builder // TODO(zdapeng): executors should not outlive the parent channel. diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 7ef2f286a..bf96af6eb 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -35,6 +35,7 @@ import io.grpc.InternalGlobalInterceptors; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.ProxyDetector; import java.lang.reflect.InvocationTargetException; @@ -44,6 +45,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -133,10 +135,7 @@ public final class ManagedChannelImplBuilder ObjectPool<? extends Executor> offloadExecutorPool = DEFAULT_EXECUTOR_POOL; private final List<ClientInterceptor> interceptors = new ArrayList<>(); - final NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry(); - - // Access via getter, which may perform authority override as needed - NameResolver.Factory nameResolverFactory = nameResolverRegistry.asFactory(); + NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry(); final String target; @Nullable @@ -284,7 +283,7 @@ public final class ManagedChannelImplBuilder /** * Returns a target string for the SocketAddress. It is only used as a placeholder, because - * DirectAddressNameResolverFactory will not actually try to use it. However, it must be a valid + * DirectAddressNameResolverProvider will not actually try to use it. However, it must be a valid * URI. */ @VisibleForTesting @@ -327,7 +326,10 @@ public final class ManagedChannelImplBuilder this.clientTransportFactoryBuilder = Preconditions .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); this.directServerAddress = directServerAddress; - this.nameResolverFactory = new DirectAddressNameResolverFactory(directServerAddress, authority); + NameResolverRegistry reg = new NameResolverRegistry(); + reg.register(new DirectAddressNameResolverProvider(directServerAddress, + authority)); + this.nameResolverRegistry = reg; if (channelBuilderDefaultPortProvider != null) { this.channelBuilderDefaultPortProvider = channelBuilderDefaultPortProvider; @@ -379,13 +381,20 @@ public final class ManagedChannelImplBuilder "directServerAddress is set (%s), which forbids the use of NameResolverFactory", directServerAddress); if (resolverFactory != null) { - this.nameResolverFactory = resolverFactory; + NameResolverRegistry reg = new NameResolverRegistry(); + reg.register(new NameResolverFactoryToProviderFacade(resolverFactory)); + this.nameResolverRegistry = reg; } else { - this.nameResolverFactory = nameResolverRegistry.asFactory(); + this.nameResolverRegistry = NameResolverRegistry.getDefaultRegistry(); } return this; } + ManagedChannelImplBuilder nameResolverRegistry(NameResolverRegistry resolverRegistry) { + this.nameResolverRegistry = resolverRegistry; + return this; + } + @Override public ManagedChannelImplBuilder defaultLoadBalancingPolicy(String policy) { Preconditions.checkState(directServerAddress == null, @@ -728,13 +737,16 @@ public final class ManagedChannelImplBuilder return channelBuilderDefaultPortProvider.getDefaultPort(); } - private static class DirectAddressNameResolverFactory extends NameResolver.Factory { + private static class DirectAddressNameResolverProvider extends NameResolverProvider { final SocketAddress address; final String authority; + final Collection<Class<? extends SocketAddress>> producedSocketAddressTypes; - DirectAddressNameResolverFactory(SocketAddress address, String authority) { + DirectAddressNameResolverProvider(SocketAddress address, String authority) { this.address = address; this.authority = authority; + this.producedSocketAddressTypes + = Collections.singleton(address.getClass()); } @Override @@ -763,6 +775,21 @@ public final class ManagedChannelImplBuilder public String getDefaultScheme() { return DIRECT_ADDRESS_SCHEME; } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + return producedSocketAddressTypes; + } } /** diff --git a/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java b/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java new file mode 100644 index 000000000..31c20f6e4 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import io.grpc.NameResolver; +import io.grpc.NameResolver.Args; +import io.grpc.NameResolverProvider; +import java.net.URI; + +public class NameResolverFactoryToProviderFacade extends NameResolverProvider { + + private NameResolver.Factory factory; + + NameResolverFactoryToProviderFacade(NameResolver.Factory factory) { + this.factory = factory; + } + + @Override + public NameResolver newNameResolver(URI targetUri, Args args) { + return factory.newNameResolver(targetUri, args); + } + + @Override + public String getDefaultScheme() { + return factory.getDefaultScheme(); + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index dae8b9b37..67b80bf74 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,7 +40,9 @@ import io.grpc.InternalGlobalInterceptors; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import io.grpc.NameResolver; +import io.grpc.NameResolverRegistry; import io.grpc.StaticTestingClassLoader; +import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; @@ -196,25 +199,30 @@ public class ManagedChannelImplBuilderTest { } @Test - public void nameResolverFactory_default() { - assertNotNull(builder.nameResolverFactory); + public void nameResolverRegistry_default() { + assertNotNull(builder.nameResolverRegistry); } @Test @SuppressWarnings("deprecation") public void nameResolverFactory_normal() { NameResolver.Factory nameResolverFactory = mock(NameResolver.Factory.class); + doReturn("testscheme").when(nameResolverFactory).getDefaultScheme(); assertEquals(builder, builder.nameResolverFactory(nameResolverFactory)); - assertEquals(nameResolverFactory, builder.nameResolverFactory); + assertNotNull(builder.nameResolverRegistry); + assertEquals("testscheme", builder.nameResolverRegistry.asFactory().getDefaultScheme()); } @Test @SuppressWarnings("deprecation") public void nameResolverFactory_null() { - NameResolver.Factory defaultValue = builder.nameResolverFactory; - builder.nameResolverFactory(mock(NameResolver.Factory.class)); - assertEquals(builder, builder.nameResolverFactory(null)); - assertEquals(defaultValue, builder.nameResolverFactory); + NameResolverRegistry defaultValue = builder.nameResolverRegistry; + NameResolver.Factory nameResolverFactory = mock(NameResolver.Factory.class); + doReturn("testscheme").when(nameResolverFactory).getDefaultScheme(); + builder.nameResolverFactory(nameResolverFactory); + assertNotEquals(defaultValue, builder.nameResolverRegistry); + builder.nameResolverFactory(null); + assertEquals(defaultValue, builder.nameResolverRegistry); } @Test(expected = IllegalStateException.class) @@ -327,6 +335,8 @@ public class ManagedChannelImplBuilderTest { .thenReturn(clock.getScheduledExecutorService()); when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) .thenReturn(mockClientTransportFactory); + when(mockClientTransportFactory.getSupportedSocketAddressTypes()) + .thenReturn(Collections.singleton(InetSocketAddress.class)); builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); @@ -341,6 +351,8 @@ public class ManagedChannelImplBuilderTest { .thenReturn(clock.getScheduledExecutorService()); when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) .thenReturn(mockClientTransportFactory); + when(mockClientTransportFactory.getSupportedSocketAddressTypes()) + .thenReturn(Collections.singleton(InetSocketAddress.class)); builder = new ManagedChannelImplBuilder(DUMMY_TARGET, mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)) @@ -350,6 +362,41 @@ public class ManagedChannelImplBuilderTest { } @Test + public void transportDoesNotSupportAddressTypes() { + when(mockClientTransportFactory.getScheduledExecutorService()) + .thenReturn(clock.getScheduledExecutorService()); + when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) + .thenReturn(mockClientTransportFactory); + when(mockClientTransportFactory.getSupportedSocketAddressTypes()) + .thenReturn(Collections.singleton(InProcessSocketAddress.class)); + + builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, + mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); + try { + ManagedChannel unused = grpcCleanupRule.register(builder.build()); + fail("Should fail"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo( + "Address types of NameResolver 'dns' for 'valid:1234' not supported by transport"); + } + } + + @Test + public void transportAddressTypeCompatibilityCheckSkipped() { + when(mockClientTransportFactory.getScheduledExecutorService()) + .thenReturn(clock.getScheduledExecutorService()); + when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) + .thenReturn(mockClientTransportFactory); + when(mockClientTransportFactory.getSupportedSocketAddressTypes()) + .thenReturn(null); + + builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, + mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); + // should not fail + ManagedChannel unused = grpcCleanupRule.register(builder.build()); + } + + @Test public void overrideAuthority_default() { assertNull(builder.authorityOverride); } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index b63d53a6f..452e07191 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -24,17 +24,22 @@ import static org.mockito.Mockito.mock; import io.grpc.ChannelLogger; import io.grpc.NameResolver; +import io.grpc.NameResolver.Args; import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import io.grpc.ProxyDetector; import io.grpc.SynchronizationContext; +import io.grpc.inprocess.InProcessSocketAddress; import java.lang.Thread.UncaughtExceptionHandler; +import java.net.InetSocketAddress; import java.net.URI; +import java.util.Collections; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link ManagedChannelImpl#getNameResolver( - * String, String,NameResolver.Factory, NameResolver.Args)}. */ +/** Unit tests for ManagedChannelImpl#getNameResolver(). */ @RunWith(JUnit4.class) public class ManagedChannelImplGetNameResolverTest { private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder() @@ -68,9 +73,10 @@ public class ManagedChannelImplGetNameResolverTest { String target = "foo.googleapis.com:8080"; String overrideAuthority = "override.authority"; URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); - NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory(expectedUri.getScheme()); + NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); NameResolver nameResolver = ManagedChannelImpl.getNameResolver( - target, overrideAuthority, nameResolverFactory, NAMERESOLVER_ARGS); + target, overrideAuthority, nameResolverRegistry, NAMERESOLVER_ARGS, + Collections.singleton(InetSocketAddress.class)); assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); } @@ -116,10 +122,21 @@ public class ManagedChannelImplGetNameResolverTest { } @Test - public void validTargetNoResovler() { - NameResolver.Factory nameResolverFactory = new NameResolver.Factory() { + public void validTargetNoResolver() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + NameResolverProvider nameResolverProvider = new NameResolverProvider() { @Override - public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public NameResolver newNameResolver(URI targetUri, Args args) { return null; } @@ -128,41 +145,81 @@ public class ManagedChannelImplGetNameResolverTest { return "defaultscheme"; } }; + nameResolverRegistry.register(nameResolverProvider); + try { + ManagedChannelImpl.getNameResolver( + "foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, + Collections.singleton(InetSocketAddress.class)); + fail("Should fail"); + } catch (IllegalArgumentException e) { + // expected + } + } + + @Test + public void validTargetNoProvider() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); try { ManagedChannelImpl.getNameResolver( - "foo.googleapis.com:8080", null, nameResolverFactory, NAMERESOLVER_ARGS); + "foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, + Collections.singleton(InetSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { // expected } } + @Test + public void validTargetProviderAddrTypesNotSupported() { + NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); + try { + ManagedChannelImpl.getNameResolver( + "testscheme:///foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, + Collections.singleton(InProcessSocketAddress.class)); + fail("Should fail"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo( + "Address types of NameResolver 'testscheme' for " + + "'testscheme:///foo.googleapis.com:8080' not supported by transport"); + } + } + + private void testValidTarget(String target, String expectedUriString, URI expectedUri) { - NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory(expectedUri.getScheme()); + NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); FakeNameResolver nameResolver = (FakeNameResolver) ((RetryingNameResolver) ManagedChannelImpl.getNameResolver( - target, null, nameResolverFactory, NAMERESOLVER_ARGS)).getRetriedNameResolver(); + target, null, nameResolverRegistry, NAMERESOLVER_ARGS, + Collections.singleton(InetSocketAddress.class))).getRetriedNameResolver(); assertNotNull(nameResolver); assertEquals(expectedUri, nameResolver.uri); assertEquals(expectedUriString, nameResolver.uri.toString()); } private void testInvalidTarget(String target) { - NameResolver.Factory nameResolverFactory = new FakeNameResolverFactory("dns"); + NameResolverRegistry nameResolverRegistry = getTestRegistry("dns"); try { FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver( - target, null, nameResolverFactory, NAMERESOLVER_ARGS); + target, null, nameResolverRegistry, NAMERESOLVER_ARGS, + Collections.singleton(InetSocketAddress.class)); fail("Should have failed, but got resolver with " + nameResolver.uri); } catch (IllegalArgumentException e) { // expected } } - private static class FakeNameResolverFactory extends NameResolver.Factory { + private static NameResolverRegistry getTestRegistry(String expectedScheme) { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + FakeNameResolverProvider nameResolverProvider = new FakeNameResolverProvider(expectedScheme); + nameResolverRegistry.register(nameResolverProvider); + return nameResolverRegistry; + } + + private static class FakeNameResolverProvider extends NameResolverProvider { final String expectedScheme; - FakeNameResolverFactory(String expectedScheme) { + FakeNameResolverProvider(String expectedScheme) { this.expectedScheme = expectedScheme; } @@ -178,6 +235,16 @@ public class ManagedChannelImplGetNameResolverTest { public String getDefaultScheme() { return expectedScheme; } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } } private static class FakeNameResolver extends NameResolver { diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index faecfdfe5..e50eeaf76 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -66,6 +66,7 @@ import io.grpc.StringMarshaller; import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; import io.grpc.internal.TestUtils.MockClientTransportInfo; +import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.ArrayList; @@ -161,10 +162,14 @@ public class ManagedChannelImplIdlenessTest { when(mockNameResolverFactory .newNameResolver(any(URI.class), any(NameResolver.Args.class))) .thenReturn(mockNameResolver); + when(mockNameResolverFactory.getDefaultScheme()) + .thenReturn("mockscheme"); when(mockTransportFactory.getScheduledExecutorService()) .thenReturn(timer.getScheduledExecutorService()); + when(mockTransportFactory.getSupportedSocketAddressTypes()) + .thenReturn(Collections.singleton(InetSocketAddress.class)); - ManagedChannelImplBuilder builder = new ManagedChannelImplBuilder("fake://target", + ManagedChannelImplBuilder builder = new ManagedChannelImplBuilder("mockscheme:///target", new UnsupportedClientTransportFactoryBuilder(), null); builder diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index db794bf9e..dd7c7904a 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -122,6 +122,7 @@ import io.grpc.stub.ClientCalls; import io.grpc.testing.TestMethodDescriptors; import io.grpc.util.ForwardingSubchannel; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.ArrayList; @@ -287,6 +288,8 @@ public class ManagedChannelImplTest { ClientInterceptor... interceptors) { checkState(channel == null); + when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( + InetSocketAddress.class)); channel = new ManagedChannelImpl( channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Arrays.asList(interceptors), @@ -473,6 +476,8 @@ public class ManagedChannelImplTest { new FakeNameResolverFactory.Builder(expectedUri) .setServers(ImmutableList.of(addressGroup)).build(); channelBuilder.nameResolverFactory(nameResolverFactory); + when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( + InetSocketAddress.class)); channel = new ManagedChannelImpl( channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), @@ -535,6 +540,8 @@ public class ManagedChannelImplTest { new FakeNameResolverFactory.Builder(expectedUri) .setServers(ImmutableList.of(addressGroup)).build(); channelBuilder.nameResolverFactory(nameResolverFactory); + when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( + InetSocketAddress.class)); channel = new ManagedChannelImpl( channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), @@ -1718,7 +1725,7 @@ public class ManagedChannelImplTest { // Verify that resolving oob channel does not oob = helper.createResolvingOobChannelBuilder("oobauthority") .nameResolverFactory( - new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build()) + new FakeNameResolverFactory.Builder(URI.create("fake:///oobauthority")).build()) .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) .disableRetry() // irrelevant to what we test, disable retry to make verification easy @@ -2042,11 +2049,11 @@ public class ManagedChannelImplTest { } @Test - public void lbHelper_getNameResolverRegistry() { + public void lbHelper_getNonDefaultNameResolverRegistry() { createChannel(); assertThat(helper.getNameResolverRegistry()) - .isSameInstanceAs(NameResolverRegistry.getDefaultRegistry()); + .isNotSameInstanceAs(NameResolverRegistry.getDefaultRegistry()); } @Test @@ -2611,7 +2618,7 @@ public class ManagedChannelImplTest { } @Override public String getDefaultScheme() { - return "fakescheme"; + return "fake"; } }); createChannel(); @@ -3745,6 +3752,8 @@ public class ManagedChannelImplTest { } }, null); + when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( + InetSocketAddress.class)); customBuilder.executorPool = executorPool; customBuilder.channelz = channelz; ManagedChannel mychannel = customBuilder.nameResolverFactory(factory).build(); @@ -3825,7 +3834,7 @@ public class ManagedChannelImplTest { @Override public String getDefaultScheme() { - return "fakescheme"; + return "fake"; } }; channelBuilder.nameResolverFactory(factory).proxyDetector(neverProxy); diff --git a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java index 4558c6392..0d050a09a 100644 --- a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java +++ b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java @@ -48,6 +48,7 @@ import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; +import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.ArrayList; @@ -158,6 +159,8 @@ public class ServiceConfigErrorHandlingTest { private void createChannel(ClientInterceptor... interceptors) { checkState(channel == null); + when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( + InetSocketAddress.class)); channel = new ManagedChannelImpl( channelBuilder, diff --git a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java index 066992018..93413aa22 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java @@ -42,6 +42,8 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; @@ -283,6 +285,11 @@ public final class CronetChannelBuilder extends ForwardingChannelBuilder2<Cronet SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timeoutService); } } + + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } /** diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java index ce833d5c4..8ad292a3d 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java @@ -74,7 +74,7 @@ public final class GoogleCloudToProdNameResolverProvider extends NameResolverPro } @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } diff --git a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java index 3970c281e..8952ea1d8 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java @@ -94,7 +94,7 @@ final class SecretGrpclbNameResolverProvider { } @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } } diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index aa53dbc87..ccc176b61 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -33,6 +33,8 @@ import io.grpc.internal.ManagedChannelImplBuilder; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.SharedResourceHolder; import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; @@ -284,5 +286,10 @@ public final class InProcessChannelBuilder extends SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timerService); } } + + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return Collections.singleton(InProcessSocketAddress.class); + } } } diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessNameResolver.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessNameResolver.java new file mode 100644 index 000000000..f2e50eade --- /dev/null +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessNameResolver.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.inprocess; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Preconditions; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +final class InProcessNameResolver extends NameResolver { + private Listener2 listener; + private final String authority; + + InProcessNameResolver(String authority, String targetPath) { + checkArgument(authority == null, "non-null authority not supported"); + this.authority = targetPath; + } + + @Override + public String getServiceAuthority() { + return this.authority; + } + + @Override + public void start(Listener2 listener) { + Preconditions.checkState(this.listener == null, "already started"); + this.listener = checkNotNull(listener, "listener"); + resolve(); + } + + @Override + public void refresh() { + resolve(); + } + + private void resolve() { + ResolutionResult.Builder resolutionResultBuilder = ResolutionResult.newBuilder(); + List<EquivalentAddressGroup> servers = new ArrayList<>(1); + servers.add(new EquivalentAddressGroup(new InProcessSocketAddress(authority))); + resolutionResultBuilder.setAddresses(Collections.unmodifiableList(servers)); + listener.onResult(resolutionResultBuilder.build()); + } + + @Override + public void shutdown() {} +} diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessNameResolverProvider.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessNameResolverProvider.java new file mode 100644 index 000000000..98a37fc40 --- /dev/null +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessNameResolverProvider.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.inprocess; + +import com.google.common.base.Preconditions; +import io.grpc.Internal; +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; + +@Internal +public final class InProcessNameResolverProvider extends NameResolverProvider { + + private static final String SCHEME = "inprocess"; + + @Override + public InProcessNameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + return new InProcessNameResolver(targetUri.getAuthority(), getTargetPathFromUri(targetUri)); + } else { + return null; + } + } + + static String getTargetPathFromUri(URI targetUri) { + Preconditions.checkArgument(SCHEME.equals(targetUri.getScheme()), "scheme must be " + SCHEME); + String targetPath = targetUri.getPath(); + if (targetPath == null) { + targetPath = Preconditions.checkNotNull(targetUri.getSchemeSpecificPart(), "targetPath"); + } + return targetPath; + } + + @Override + public String getDefaultScheme() { + return SCHEME; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 3; + } + + @Override + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + return Collections.singleton(InProcessSocketAddress.class); + } +} diff --git a/inprocess/src/main/resources/META-INF/services/io.grpc.NameResolverProvider b/inprocess/src/main/resources/META-INF/services/io.grpc.NameResolverProvider new file mode 100644 index 000000000..a05425083 --- /dev/null +++ b/inprocess/src/main/resources/META-INF/services/io.grpc.NameResolverProvider @@ -0,0 +1 @@ +io.grpc.inprocess.InProcessNameResolverProvider diff --git a/inprocess/src/test/java/io/grpc/inprocess/InProcessNameResolverProviderTest.java b/inprocess/src/test/java/io/grpc/inprocess/InProcessNameResolverProviderTest.java new file mode 100644 index 000000000..0a2e85dcf --- /dev/null +++ b/inprocess/src/test/java/io/grpc/inprocess/InProcessNameResolverProviderTest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.inprocess; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.verify; + +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import java.net.SocketAddress; +import java.net.URI; +import java.util.List; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link InProcessNameResolverProvider}. */ +@RunWith(JUnit4.class) +public class InProcessNameResolverProviderTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private NameResolver.Listener2 mockListener; + + @Captor + private ArgumentCaptor<NameResolver.ResolutionResult> resultCaptor; + + InProcessNameResolverProvider inProcessNameResolverProvider = new InProcessNameResolverProvider(); + + + @Test + public void testRelativePath() { + InProcessNameResolver inProcessNameResolver = + inProcessNameResolverProvider.newNameResolver(URI.create("inprocess:proc.proc"), null); + assertThat(inProcessNameResolver).isNotNull(); + inProcessNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List<EquivalentAddressGroup> list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List<SocketAddress> addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(InProcessSocketAddress.class); + InProcessSocketAddress domainSocketAddress = (InProcessSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.getName()).isEqualTo("proc.proc"); + } + + @Test + public void testAbsolutePath() { + InProcessNameResolver inProcessNameResolver = + inProcessNameResolverProvider.newNameResolver(URI.create("inprocess:/proc.proc"), null); + assertThat(inProcessNameResolver).isNotNull(); + inProcessNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List<EquivalentAddressGroup> list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List<SocketAddress> addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(InProcessSocketAddress.class); + InProcessSocketAddress domainSocketAddress = (InProcessSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.getName()).isEqualTo("/proc.proc"); + } + + @Test + public void testAbsoluteAlternatePath() { + InProcessNameResolver udsNameResolver = + inProcessNameResolverProvider.newNameResolver(URI.create("inprocess:///proc.proc"), null); + assertThat(udsNameResolver).isNotNull(); + udsNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List<EquivalentAddressGroup> list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List<SocketAddress> addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(InProcessSocketAddress.class); + InProcessSocketAddress domainSocketAddress = (InProcessSocketAddress) addresses.get(0); + assertThat(domainSocketAddress.getName()).isEqualTo("/proc.proc"); + } + + @Test + public void testWrongScheme() { + assertNull(inProcessNameResolverProvider.newNameResolver(URI.create( + "badscheme://localhost/proc.proc"), null)); + } + + @Test + public void testPathWithAuthority() { + try { + inProcessNameResolverProvider.newNameResolver( + URI.create("inprocess://localhost/proc.proc"), null); + fail("exception expected"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo( + "non-null authority not supported"); + } + } +} diff --git a/inprocess/src/test/java/io/grpc/inprocess/InProcessNameResolverTest.java b/inprocess/src/test/java/io/grpc/inprocess/InProcessNameResolverTest.java new file mode 100644 index 000000000..b48aab123 --- /dev/null +++ b/inprocess/src/test/java/io/grpc/inprocess/InProcessNameResolverTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.inprocess; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.verify; + +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import java.net.SocketAddress; +import java.util.List; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link InProcessNameResolver}. */ +@RunWith(JUnit4.class) +public class InProcessNameResolverTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private NameResolver.Listener2 mockListener; + + @Captor + private ArgumentCaptor<NameResolver.ResolutionResult> resultCaptor; + + private InProcessNameResolver inProcessNameResolver; + + @Test + public void testValidTargetPath() { + inProcessNameResolver = new InProcessNameResolver(null, "proc.proc"); + inProcessNameResolver.start(mockListener); + verify(mockListener).onResult(resultCaptor.capture()); + NameResolver.ResolutionResult result = resultCaptor.getValue(); + List<EquivalentAddressGroup> list = result.getAddresses(); + assertThat(list).isNotNull(); + assertThat(list).hasSize(1); + EquivalentAddressGroup eag = list.get(0); + assertThat(eag).isNotNull(); + List<SocketAddress> addresses = eag.getAddresses(); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0)).isInstanceOf(InProcessSocketAddress.class); + InProcessSocketAddress socketAddress = (InProcessSocketAddress) addresses.get(0); + assertThat(socketAddress.getName()).isEqualTo("proc.proc"); + assertThat(inProcessNameResolver.getServiceAuthority()).isEqualTo("proc.proc"); + } + + @Test + public void testNonNullAuthority() { + try { + inProcessNameResolver = new InProcessNameResolver("authority", "proc.proc"); + fail("exception expected"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); + } + } +} diff --git a/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java index 9e63a3d9d..420a9c4a8 100644 --- a/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java +++ b/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java @@ -174,7 +174,7 @@ public class InProcessTransportTest extends AbstractTransportTest { fail("Call should fail."); } catch (ExecutionException ex) { StatusRuntimeException s = (StatusRuntimeException)ex.getCause(); - assertEquals(s.getStatus().getCode(), Code.UNIMPLEMENTED); + assertEquals(Code.UNIMPLEMENTED, s.getStatus().getCode()); } } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java index 778e8bdf9..6659af68a 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2NettyLocalChannelTest.java @@ -57,7 +57,7 @@ public class Http2NettyLocalChannelTest extends AbstractInteropTest { NettyChannelBuilder builder = NettyChannelBuilder .forAddress(new LocalAddress("in-process-1")) .negotiationType(NegotiationType.PLAINTEXT) - .channelType(LocalChannel.class) + .channelType(LocalChannel.class, LocalAddress.class) .eventLoopGroup(eventLoopGroup) .flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW) .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index 72ed8bf97..45ea303e5 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -191,7 +191,7 @@ public class RetryTest { rawServiceConfig.put("methodConfig", Arrays.<Object>asList(methodConfig)); channel = cleanupRule.register( NettyChannelBuilder.forAddress(localAddress) - .channelType(LocalChannel.class) + .channelType(LocalChannel.class, LocalAddress.class) .eventLoopGroup(group) .usePlaintext() .enableRetry() diff --git a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java index c5ad99181..1848b475d 100644 --- a/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/InternalNettyChannelBuilder.java @@ -23,6 +23,7 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourcePool; import io.grpc.internal.TransportTracer; import io.netty.channel.socket.nio.NioSocketChannel; +import java.net.InetSocketAddress; /** * Internal {@link NettyChannelBuilder} accessor. This is intended for usage internal to the gRPC @@ -100,7 +101,7 @@ public final class InternalNettyChannelBuilder { * io.netty.channel.EventLoopGroup}. */ public static void useNioTransport(NettyChannelBuilder builder) { - builder.channelType(NioSocketChannel.class); + builder.channelType(NioSocketChannel.class, InetSocketAddress.class); builder .eventLoopGroupPool(SharedResourcePool.forResource(Utils.NIO_WORKER_EVENT_LOOP_GROUP)); } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 138e11f6d..305ad1284 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -59,6 +59,8 @@ import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.ssl.SslContext; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.Executor; @@ -116,6 +118,8 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh */ private final boolean useGetForSafeMethods = false; + private Class<? extends SocketAddress> transportSocketType = InetSocketAddress.class; + /** * Creates a new builder with the given server address. This factory method is primarily intended * for using Netty Channel types other than SocketChannel. {@link #forAddress(String, int)} should @@ -260,8 +264,23 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh */ @CanIgnoreReturnValue public NettyChannelBuilder channelType(Class<? extends Channel> channelType) { + return channelType(channelType, null); + } + + /** + * Similar to {@link #channelType(Class)} above but allows the + * caller to specify the socket-type associated with the channelType. + * + * @param channelType the type of {@link Channel} to use. + * @param transportSocketType the associated {@link SocketAddress} type. If {@code null}, then + * no compatibility check is performed between channel transport and name-resolver addresses. + */ + @CanIgnoreReturnValue + public NettyChannelBuilder channelType(Class<? extends Channel> channelType, + @Nullable Class<? extends SocketAddress> transportSocketType) { checkNotNull(channelType, "channelType"); - return channelFactory(new ReflectiveChannelFactory<>(channelType)); + return channelFactory(new ReflectiveChannelFactory<>(channelType), + transportSocketType); } /** @@ -279,7 +298,22 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh */ @CanIgnoreReturnValue public NettyChannelBuilder channelFactory(ChannelFactory<? extends Channel> channelFactory) { + return channelFactory(channelFactory, null); + } + + /** + * Similar to {@link #channelFactory(ChannelFactory)} above but allows the + * caller to specify the socket-type associated with the channelFactory. + * + * @param channelFactory the {@link ChannelFactory} to use. + * @param transportSocketType the associated {@link SocketAddress} type. If {@code null}, then + * no compatibility check is performed between channel transport and name-resolver addresses. + */ + @CanIgnoreReturnValue + public NettyChannelBuilder channelFactory(ChannelFactory<? extends Channel> channelFactory, + @Nullable Class<? extends SocketAddress> transportSocketType) { this.channelFactory = checkNotNull(channelFactory, "channelFactory"); + this.transportSocketType = transportSocketType; return this; } @@ -541,7 +575,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh negotiator, channelFactory, channelOptions, eventLoopGroupPool, autoFlowControl, flowControlWindow, maxInboundMessageSize, maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, - transportTracerFactory, localSocketPicker, useGetForSafeMethods); + transportTracerFactory, localSocketPicker, useGetForSafeMethods, transportSocketType); } @VisibleForTesting @@ -626,6 +660,10 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh return this; } + static Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + private final class DefaultProtocolNegotiator implements ProtocolNegotiator.ClientFactory { private NegotiationType negotiationType = NegotiationType.TLS; private SslContext sslContext; @@ -680,6 +718,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh private final boolean useGetForSafeMethods; private boolean closed; + private final Class<? extends SocketAddress> transportSocketType; NettyTransportFactory( ProtocolNegotiator protocolNegotiator, @@ -688,7 +727,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, TransportTracer.Factory transportTracerFactory, LocalSocketPicker localSocketPicker, - boolean useGetForSafeMethods) { + boolean useGetForSafeMethods, Class<? extends SocketAddress> transportSocketType) { this.protocolNegotiator = checkNotNull(protocolNegotiator, "protocolNegotiator"); this.channelFactory = channelFactory; this.channelOptions = new HashMap<ChannelOption<?>, Object>(channelOptions); @@ -706,6 +745,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh this.localSocketPicker = localSocketPicker != null ? localSocketPicker : new LocalSocketPicker(); this.useGetForSafeMethods = useGetForSafeMethods; + this.transportSocketType = transportSocketType; } @Override @@ -759,7 +799,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh result.negotiator.newNegotiator(), channelFactory, channelOptions, groupPool, autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory, localSocketPicker, - useGetForSafeMethods); + useGetForSafeMethods, transportSocketType); return new SwapChannelCredentialsResult(factory, result.callCredentials); } @@ -773,5 +813,11 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2<NettyCh protocolNegotiator.close(); groupPool.returnObject(group); } + + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return transportSocketType == null ? null + : Collections.singleton(transportSocketType); + } } } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java index 7cc77c150..1b22a95a4 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java @@ -19,10 +19,8 @@ package io.grpc.netty; import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; -import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Collection; -import java.util.Collections; /** Provider for {@link NettyChannelBuilder} instances. */ @Internal @@ -59,6 +57,6 @@ public final class NettyChannelProvider extends ManagedChannelProvider { @Override protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { - return Collections.singleton(InetSocketAddress.class); + return NettyChannelBuilder.getSupportedSocketAddressTypes(); } } diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java index ffc07ff6e..9f594193b 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java @@ -65,7 +65,7 @@ public final class UdsNameResolverProvider extends NameResolverProvider { } @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(DomainSocketAddress.class); } } diff --git a/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java index 59b50657a..4e9895da0 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java @@ -16,6 +16,7 @@ package io.grpc.netty; +import com.google.common.base.Preconditions; import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; @@ -51,11 +52,12 @@ public final class UdsNettyChannelProvider extends ManagedChannelProvider { @Override public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds) { + Preconditions.checkState(isAvailable()); NewChannelBuilderResult result = new NettyChannelProvider().newChannelBuilder(target, creds); if (result.getChannelBuilder() != null) { ((NettyChannelBuilder) result.getChannelBuilder()) .eventLoopGroupPool(SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP)) - .channelType(Utils.EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE); + .channelType(Utils.EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE, DomainSocketAddress.class); } return result; } diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index 032b04052..8a34a5d24 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -32,6 +32,7 @@ import io.grpc.netty.ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactor import io.netty.channel.Channel; import io.netty.channel.ChannelFactory; import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.handler.ssl.SslContext; import java.net.InetSocketAddress; @@ -80,9 +81,14 @@ public class NettyChannelBuilderTest { overrideAuthorityIsReadableHelper(builder, "override:5678"); } + private static SocketAddress getTestSocketAddress() { + return new InetSocketAddress("1.1.1.1", 80); + } + @Test public void overrideAuthorityIsReadableForSocketAddress() throws Exception { - NettyChannelBuilder builder = NettyChannelBuilder.forAddress(new SocketAddress(){}); + NettyChannelBuilder builder = NettyChannelBuilder.forAddress( + getTestSocketAddress()); overrideAuthorityIsReadableHelper(builder, "override:5678"); } @@ -99,7 +105,7 @@ public class NettyChannelBuilderTest { @Test public void failOverrideInvalidAuthority() { - NettyChannelBuilder builder = new NettyChannelBuilder(new SocketAddress(){}); + NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("Invalid authority:"); @@ -109,7 +115,7 @@ public class NettyChannelBuilderTest { @Test public void disableCheckAuthorityAllowsInvalidAuthority() { - NettyChannelBuilder builder = new NettyChannelBuilder(new SocketAddress(){}) + NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()) .disableCheckAuthority(); Object unused = builder.overrideAuthority("[invalidauthority") @@ -119,7 +125,7 @@ public class NettyChannelBuilderTest { @Test public void enableCheckAuthorityFailOverrideInvalidAuthority() { - NettyChannelBuilder builder = new NettyChannelBuilder(new SocketAddress(){}) + NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()) .disableCheckAuthority() .enableCheckAuthority(); @@ -139,14 +145,14 @@ public class NettyChannelBuilderTest { @Test public void sslContextCanBeNull() { - NettyChannelBuilder builder = new NettyChannelBuilder(new SocketAddress(){}); + NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); builder.sslContext(null); } @Test public void failIfSslContextIsNotClient() { SslContext sslContext = mock(SslContext.class); - NettyChannelBuilder builder = new NettyChannelBuilder(new SocketAddress(){}); + NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("Server SSL context can not be used for client channel"); @@ -168,7 +174,7 @@ public class NettyChannelBuilderTest { @Test public void failNegotiationTypeWithChannelCredentials_socketAddress() { NettyChannelBuilder builder = NettyChannelBuilder.forAddress( - new SocketAddress(){}, InsecureChannelCredentials.create()); + getTestSocketAddress(), InsecureChannelCredentials.create()); thrown.expect(IllegalStateException.class); thrown.expectMessage("Cannot change security when using ChannelCredentials"); @@ -265,7 +271,7 @@ public class NettyChannelBuilderTest { @Test public void assertEventLoopAndChannelType_onlyTypeProvided() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); - builder.channelType(LocalChannel.class); + builder.channelType(LocalChannel.class, LocalAddress.class); thrown.expect(IllegalStateException.class); thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); @@ -298,7 +304,7 @@ public class NettyChannelBuilderTest { public void assertEventLoopAndChannelType_bothProvided() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); builder.eventLoopGroup(mock(EventLoopGroup.class)); - builder.channelType(LocalChannel.class); + builder.channelType(LocalChannel.class, LocalAddress.class); builder.assertEventLoopAndChannelType(); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 8e9ed75bf..24e508363 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -60,6 +60,8 @@ import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Collections; import java.util.EnumSet; import java.util.Set; import java.util.concurrent.Executor; @@ -722,6 +724,10 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2<OkHttp return trustManagerFactory.getTrustManagers(); } + static Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + static final class SslSocketFactoryResult { /** {@code null} implies plaintext if {@code error == null}. */ public final SSLSocketFactory factory; @@ -898,5 +904,10 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2<OkHttp executorPool.returnObject(executor); scheduledExecutorServicePool.returnObject(scheduledExecutorService); } + + @Override + public Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { + return OkHttpChannelBuilder.getSupportedSocketAddressTypes(); + } } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java index 17a2512a6..bf2a9be6f 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java @@ -20,10 +20,8 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.InternalServiceProviders; import io.grpc.ManagedChannelProvider; -import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Collection; -import java.util.Collections; /** * Provider for {@link OkHttpChannelBuilder} instances. @@ -64,6 +62,6 @@ public final class OkHttpChannelProvider extends ManagedChannelProvider { @Override protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() { - return Collections.singleton(InetSocketAddress.class); + return OkHttpChannelBuilder.getSupportedSocketAddressTypes(); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index 4875a85ea..6b16c11fc 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -105,7 +105,7 @@ public final class XdsNameResolverProvider extends NameResolverProvider { } @Override - protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { + public Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } |